Ejemplo n.º 1
0
    def __init__(self):
        """Configure, build and train an EfficientDet model in two stages.

        Reads every hyper-parameter from ``configs.parse_config()`` and stores
        it on ``self``. It then builds an ``EfficientDetBackbone`` for the
        classes listed in ``cfg.classes_path`` and initialises it from a
        pretrained checkpoint in ``cfg.pretrain_dir``. Next it creates the focal
        loss and reads, shuffles and sizes the train/validation split of the
        annotation lines. Finally it runs ``train_first_stage()`` followed by
        ``train_second_stage()``.

        Raises:
            FileNotFoundError: if no file in ``cfg.pretrain_dir`` has
                ``str(cfg.phi)`` in its name.
        """
        cfg = configs.parse_config()

        # ---- Hyper-parameters -------------------------------------------
        self.input_shape = (cfg.input_sizes[cfg.phi], cfg.input_sizes[cfg.phi])
        self.lr_first = cfg.learning_rate_first_stage
        self.Batch_size_first = cfg.Batch_size_first_stage
        self.Init_Epoch = cfg.Init_Epoch
        self.Freeze_Epoch = cfg.Freeze_Epoch
        self.opt_weight_decay = cfg.opt_weight_decay
        self.CosineAnnealingLR_T_max = cfg.CosineAnnealingLR_T_max
        self.CosineAnnealingLR_eta_min = cfg.CosineAnnealingLR_eta_min
        self.StepLR_step_size = cfg.StepLR_step_size
        self.StepLR_gamma = cfg.StepLR_gamma
        self.num_workers = cfg.num_workers
        self.Save_num_epoch = cfg.Save_num_epoch
        self.lr_second = cfg.learning_rate_second_stage
        self.Batch_size_second = cfg.Batch_size_second_stage
        self.Unfreeze_Epoch = cfg.Unfreeze_Epoch

        # ---- Training tricks --------------------------------------------
        self.Cosine_lr, self.mosaic = cfg.Cosine_lr, cfg.use_mosaic
        self.Cuda = torch.cuda.is_available()
        # Attribute and config key are spelled "smoooth" (sic); other code
        # relies on these names, so they are kept as-is.
        self.smoooth_label = cfg.smoooth_label
        self.Use_Data_Loader, self.annotation_path = cfg.Use_Data_Loader, cfg.train_annotation_path

        # ---- Classes ----------------------------------------------------
        self.classes_path = cfg.classes_path
        self.class_names = self.get_classes(self.classes_path)
        self.num_classes = len(self.class_names)

        # ---- Model + pretrained weights ---------------------------------
        self.model = EfficientDetBackbone(self.num_classes, cfg.phi)
        # Pick the checkpoint whose file name mentions the compound
        # coefficient phi (e.g. "...d0..." for phi=0).
        candidates = [name for name in os.listdir(cfg.pretrain_dir) if str(cfg.phi) in name]
        if not candidates:
            raise FileNotFoundError(
                'No pretrained weight file containing "{}" found in {}'.format(cfg.phi, cfg.pretrain_dir))
        # os.path.join works whether or not pretrain_dir ends with a separator;
        # plain string concatenation silently required a trailing "/".
        weight_path = os.path.join(cfg.pretrain_dir, candidates[0])

        # Loading pretrained weights speeds up convergence.
        print('Loading pretrain_weights into state dict...')
        model_dict = self.model.state_dict()
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # host. load_state_dict copies each tensor onto the parameter's own
        # device, so the result is the same either way.
        pretrained_dict = torch.load(weight_path, map_location='cpu')
        # Keep only tensors the model actually has, with an identical shape.
        # Keys missing from the model would otherwise raise KeyError, and
        # shape mismatches (e.g. a different number of classes) are dropped.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
        }
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)
        print('Finished!')

        self.net = self.model.train()
        if self.Cuda:
            # Multi-GPU training via DataParallel. The original author noted
            # that this setup has an (unspecified) problem.
            self.net = torch.nn.DataParallel(self.model)
            cudnn.benchmark = True
            self.net = self.net.cuda()

        # ---- Loss -------------------------------------------------------
        self.efficient_loss = FocalLoss()

        # ---- Train / validation split -----------------------------------
        # A fraction cfg.val_split of the lines is used for validation and
        # the remaining 1 - cfg.val_split for training.
        val_split = cfg.val_split
        with open(self.annotation_path) as f:
            self.lines = f.readlines()
        # A private fixed-seed generator keeps the split reproducible across
        # runs. RandomState(101).shuffle produces the same permutation as
        # np.random.seed(101) + np.random.shuffle. Unlike that pair, it leaves
        # the global numpy RNG untouched instead of re-seeding it from
        # entropy afterwards.
        np.random.RandomState(101).shuffle(self.lines)
        self.num_val = int(len(self.lines) * val_split)
        self.num_train = len(self.lines) - self.num_val

        self.train_first_stage()
        self.train_second_stage()
Ejemplo n.º 2
0
    # Load the pretrained weights from model_path into the model, keeping only
    # the checkpoint entries whose shape matches the model's own parameter.
    # NOTE(review): model_dict[k] raises KeyError if the checkpoint contains a
    # key the model does not have. torch.load is called without map_location,
    # so a GPU-saved checkpoint may fail to load on a CPU-only machine. Confirm
    # that neither case occurs for the checkpoints used here.
    print('加载与训练权重')
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('初始权重加载完成!')

    # Switch the model to training mode.
    net = model.train()

    if Cuda:
        # Wrap for multi-GPU training, enable cuDNN autotuning and move to GPU.
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    efficient_loss = FocalLoss()

    # 10% of the annotation lines are used for validation, 90% for training.
    val_split = 0.1
    with open(annotation_path) as f:
        lines = f.readlines()

    # A fixed random seed guarantees the same train/validation split when
    # training is restarted.
    np.random.seed(10101)
    np.random.shuffle(lines)
    # Re-seed the global RNG from fresh entropy so later randomness is not
    # fixed. Note this also discards any seed the caller set earlier.
    np.random.seed(None)
    num_val = int(len(lines) * val_split)
    num_train = len(lines) - num_val

    # ------------------------------------------------------#
    #   The backbone's extracted features are generic, so freezing it for
    #   the first training phase speeds up training.
Ejemplo n.º 3
0
        # Load the checkpoint onto `device`, then keep only the entries whose
        # shape matches the corresponding parameter in model_dict.
        # NOTE(review): model_dict[k] raises KeyError if the checkpoint contains
        # a key the model lacks. Confirm the checkpoints always match the model.
        pretrained_dict = torch.load(model_path, map_location=device)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if np.shape(model_dict[k]) == np.shape(v)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    # Switch to training mode. With CUDA, wrap in DataParallel for multi-GPU,
    # enable cuDNN autotuning and move the model to the GPU.
    model_train = model.train()
    if Cuda:
        model_train = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        model_train = model_train.cuda()

    focal_loss = FocalLoss()
    # Loss history is recorded under the "logs/" directory.
    loss_history = LossHistory("logs/")

    #---------------------------#
    #   Read the dataset's train / validation annotation txt files
    #---------------------------#
    with open(train_annotation_path) as f:
        train_lines = f.readlines()
    with open(val_annotation_path) as f:
        val_lines = f.readlines()
    num_train = len(train_lines)
    num_val = len(val_lines)

    #------------------------------------------------------#
    #   The backbone's extracted features are generic, so freezing it speeds
    #   up training and also keeps the pretrained weights from being damaged
    #   early in training.