Example 1
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            self._train_per_epoch(curr_epoch, train_loss_record)

            # adjust the learning rate once per epoch
            if not self.arg_dict["sche_usebatch"]:
                self.sche.step()

            # save after every epoch; the saved state targets epoch curr_epoch + 1
            save_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                current_epoch=curr_epoch + 1,
                full_net_path=self.path_dict["final_full_net"],
                state_net_path=self.path_dict["final_state_net"],
            )  # save the parameters

        if self.arg_dict["use_amp"]:
            # https://github.com/NVIDIA/apex/issues/567
            with self.amp.disable_casts():
                construct_print(
                    "When evaluating, we wish to evaluate in pure fp32.")
                self.test()
        else:
            self.test()
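
Note: every snippet on this page calls construct_print from the project's utility module, whose implementation is not shown here. A minimal sketch, assuming it is nothing more than a print wrapper that draws a banner around the message, might look like this:

# Sketch only (an assumption, not the project's actual code): wrap print()
# so that log messages stand out visually.
def construct_print(msg, total_length=80):
    msg = str(msg)
    pad = max(total_length - len(msg) - 2, 0) // 2
    print("#" * pad + f" {msg} " + "#" * pad)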
Example 2
    def __init__(self, root, in_size, training, prefix, use_bigt=False):
        self.training = training
        self.use_bigt = use_bigt

        if os.path.isdir(root):
            construct_print(f"{root} is an image folder, we will test on it.")
            self.imgs = _make_dataset(root)
        elif os.path.isfile(root):
            construct_print(
                f"{root} is a list of images, we will use these paths to read the "
                f"corresponding image")
            self.imgs = _make_dataset_from_list(root, prefix=prefix)
        else:
            raise NotImplementedError

        if self.training:
            self.joint_transform = Compose([
                JointResize(in_size),
                RandomHorizontallyFlip(),
                RandomRotate(10)
            ])
            img_transform = [transforms.ColorJitter(0.1, 0.1, 0.1)]
            self.mask_transform = transforms.ToTensor()
        else:
            # If a tuple is given, resize to exactly that size; if a single number is given, scale proportionally so that the short side equals it.
            img_transform = [
                transforms.Resize((in_size, in_size),
                                  interpolation=Image.BILINEAR),
            ]
        self.img_transform = transforms.Compose([
            *img_transform,
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
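
The Compose, JointResize, RandomHorizontallyFlip, and RandomRotate helpers used above are the project's joint transforms and are not shown here. A minimal sketch of such a joint Compose (an assumption inferred from how it is called) captures the key idea: each transform receives and returns the (img, mask) pair, so random flips and rotations stay synchronized between image and mask:

class Compose:
    # Sketch of the joint-transform container assumed by these examples.
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)  # each transform handles the pair jointly
        return img, mask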
Example 3
    def __init__(self, root, in_size, prefix=(".jpg", ".png"), training=True):
        self.training = training

        if os.path.isdir(root):
            construct_print(f"{root} is an image folder")
            self.imgs = _make_dataset(root)
        elif os.path.isfile(root):
            construct_print(f"{root} is a list of images, we will read the corresponding image")
            self.imgs = _make_dataset_from_list(root, prefix=prefix)
        else:
            print(f"{root} is invalid")
            raise NotImplementedError

        if self.training:
            self.train_joint_transform = Compose([JointResize(in_size), RandomHorizontallyFlip(), RandomRotate(10)])
            self.train_img_transform = transforms.Compose(
                [
                    transforms.ColorJitter(0.1, 0.1, 0.1),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # operates on a Tensor
                ]
            )
            self.train_mask_transform = transforms.ToTensor()
        else:
            self.test_img_trainsform = transforms.Compose(
                [
                    # If a tuple is given, resize to exactly that size; if a single number is given, scale proportionally so that the short side equals it.
                    transforms.Resize((in_size, in_size)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )
Example 4
    def __init__(self, root: dict, in_size: dict):
        super(TrainDataset, self).__init__()
        self.scale_list = [1.0]
        self.scale_list.extend(in_size['extra_scales'])

        self.total_image_paths = []
        self.total_flow_paths = []
        self.total_mask_paths = []
        self.video_name_list = []
        for root_name, root_item in root.items():  # root maps dataset names to dataset roots
            new_data_dict = read_data_dict_from_dir(root_item,
                                                    data_set="train")
            construct_print(
                f"Loading data from {root_name}: {new_data_dict['root']}")
            self.total_image_paths += new_data_dict["jpeg"]
            self.total_flow_paths += new_data_dict['flow']
            self.total_mask_paths += new_data_dict["anno"]
            self.video_name_list += new_data_dict["name_list"]
        assert len(self.total_image_paths) == len(
            self.total_mask_paths) == len(self.total_flow_paths)

        h, w = get_hw(in_size)
        self.img_transform = A.RandomBrightnessContrast(brightness_limit=0.1,
                                                        contrast_limit=0.1,
                                                        p=0.5)
        self.joint_transform = A.Compose(
            [
                A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR),
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=10, interpolation=cv2.INTER_LINEAR, p=0.5),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ],
            additional_targets=dict(flow='image'),
        )
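
For reference, a hypothetical call against this pipeline would pass the flow map under the flow key declared in additional_targets, so it undergoes exactly the same spatial transforms as the image (the variable names below are illustrative):

out = self.joint_transform(image=image, mask=mask, flow=flow)
image_t = out["image"]  # resized/flipped/rotated, normalized, CHW tensor
flow_t = out["flow"]    # treated as an extra "image" target
mask_t = out["mask"]    # masks skip Normalize in albumentations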
Example 5
def create_loader(data_path,
                  mode,
                  get_length=False,
                  prefix=(".jpg", ".png"),
                  size_list=None):
    if mode == "train":
        construct_print(f"Training on: {data_path}")
        train_set = TrainImageFolder(
            data_path,
            in_size=arg_config["input_size"],
            prefix=prefix,
            use_bigt=arg_config["use_bigt"],
        )
        loader = _make_trloader(train_set,
                                shuffle=True,
                                drop_last=True,
                                size_list=size_list)
        length_of_dataset = len(train_set)
    elif mode == "test":
        construct_print(f"Testing on: {data_path}")
        test_set = TestImageFolder(data_path,
                                   in_size=arg_config["input_size"],
                                   prefix=prefix)
        loader = _make_teloader(test_set, shuffle=False, drop_last=False)
        length_of_dataset = len(test_set)
    else:
        raise NotImplementedError

    if get_length:
        return loader, length_of_dataset
    else:
        return loader
Example 6
 def wrapper(*args, **kwargs):
     start_time = datetime.now()
     construct_print(f"{cus_msg} start: {start_time}")
     results = func(*args, **kwargs)
     construct_print(
         f"the time of {cus_msg}: {datetime.now() - start_time}")
     return results
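
Example 6 shows only the inner wrapper; the enclosing decorator factory that supplies cus_msg and func is not included. A plausible completion (an assumption, following the common timing-decorator pattern) would be:

import functools
from datetime import datetime

def time_it(cus_msg):
    # Hypothetical factory around Example 6's wrapper: cus_msg labels the log lines.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start_time = datetime.now()
            construct_print(f"{cus_msg} start: {start_time}")
            results = func(*args, **kwargs)
            construct_print(f"the time of {cus_msg}: {datetime.now() - start_time}")
            return results
        return wrapper
    return decorator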
Example 7
def main():
    construct_print("We will test the model on one GPU.")
    model = getattr(network_lib, user_config["model"])().cuda()

    # resume model only to test model.
    resume_checkpoint(
        model=model, load_path=path_config["final_full_net"], mode="onlynet",
    )
    test(model)
Example 8
def train_epoch(data_loader, curr_epoch):
    loss_record = AvgMeter()
    construct_print(f"Exp_Name: {exp_name}")
    for curr_iter_in_epoch, data in enumerate(data_loader):
        num_iter_per_epoch = len(data_loader)
        curr_iter = curr_epoch * num_iter_per_epoch + curr_iter_in_epoch

        curr_jpegs = data["image"].to(DEVICES, non_blocking=True)
        curr_masks = data["mask"].to(DEVICES, non_blocking=True)
        curr_flows = data["flow"].to(DEVICES, non_blocking=True)
        with amp.autocast(enabled=arg_config['use_amp']):
            preds = model(
                data=dict(curr_jpeg=curr_jpegs, curr_flow=curr_flows))
            seg_jpeg_logits = preds["curr_seg"]
            seg_flow_logits = preds["curr_seg_flow"]

            jpeg_loss_info = cal_total_loss(seg_logits=seg_jpeg_logits,
                                            seg_gts=curr_masks)
            flow_loss_info = cal_total_loss(seg_logits=seg_flow_logits,
                                            seg_gts=curr_masks)
            total_loss = jpeg_loss_info['loss'] + flow_loss_info['loss']
            total_loss_str = jpeg_loss_info['loss_str'] + flow_loss_info[
                'loss_str']

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        iter_loss = total_loss.item()
        batch_size = curr_jpegs.size(0)
        loss_record.update(iter_loss, batch_size)

        if (arg_config["tb_update"] > 0
                and curr_iter % arg_config["tb_update"] == 0):
            tb_recorder.record_curve("loss_avg", loss_record.avg, curr_iter)
            tb_recorder.record_curve("iter_loss", iter_loss, curr_iter)
            tb_recorder.record_curve("lr", optimizer.param_groups, curr_iter)
            tb_recorder.record_image("jpegs", curr_jpegs, curr_iter)
            tb_recorder.record_image("flows", curr_flows, curr_iter)
            tb_recorder.record_image("masks", curr_masks, curr_iter)
            tb_recorder.record_image("segs", preds["curr_seg"].sigmoid(),
                                     curr_iter)

        if (arg_config["print_freq"] > 0
                and curr_iter % arg_config["print_freq"] == 0):
            lr_string = ",".join([f"{x:.7f}" for x in scheduler.get_last_lr()])
            log = (
                f"[{curr_iter_in_epoch}:{num_iter_per_epoch},{curr_iter}:{num_iter},"
                f"{curr_epoch}:{end_epoch}][{list(curr_jpegs.shape)}]"
                f"[Lr:{lr_string}][M:{loss_record.avg:.5f},C:{iter_loss:.5f}]{total_loss_str}"
            )
            print(log)
            make_log(path_dict["tr_log"], log)

        if scheduler_usebatch:
            scheduler.step()
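
The loop above relies on a scaler created elsewhere; a setup consistent with the amp.autocast(enabled=...) call (a sketch, with names assumed) is:

from torch.cuda import amp

# GradScaler becomes a no-op when enabled=False, so the same training loop
# works with and without mixed precision.
scaler = amp.GradScaler(enabled=arg_config["use_amp"])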
Example 9
def _get_suffix(path_list):
    ext_list = list(set([os.path.splitext(p)[1] for p in path_list]))
    if len(ext_list) != 1:
        if ".png" in ext_list:
            ext = ".png"
        elif ".jpg" in ext_list:
            ext = ".jpg"
        elif ".bmp" in ext_list:
            ext = ".bmp"
        else:
            raise NotImplementedError
        construct_print(f"数据文件夹中包含多种扩展名,这里仅使用{ext}")
    else:
        ext = ext_list[0]
    return ext
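
A quick usage sketch of the fallback order above (the paths are made up):

paths = ["v/0001.jpg", "v/0002.jpg", "v/0003.png"]
ext = _get_suffix(paths)  # with mixed extensions, ".png" wins over ".jpg" and ".bmp"
assert ext == ".png"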
Example 10
 def __init__(self, root, in_size, prefix):
     if os.path.isdir(root):
         construct_print(f"{root} is an image folder, we will test on it.")
         self.imgs = _make_dataset(root)
     elif os.path.isfile(root):
         construct_print(
             f"{root} is a list of images, we will use these paths to read the "
             f"corresponding image")
         self.imgs = _make_test_dataset_from_list(root, prefix=prefix)
     else:
         raise NotImplementedError
     self.test_img_trainsform = transforms.Compose([
         # If a tuple is given, resize to exactly that size; if a single number is given, scale proportionally so that the short side equals it.
         transforms.Resize((in_size, in_size)),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
     ])
Example 11
    def __init__(self, root: dict, in_size: dict):
        """
        :param root: the actual data dictionary for this split
        :param in_size:
        """
        new_data_dict = read_data_dict_from_dir(root, data_set="val")
        self.total_image_paths = new_data_dict["jpeg"]
        self.total_flow_paths = new_data_dict["flow"]
        self.total_mask_paths = new_data_dict["anno"]
        self.video_name_list = new_data_dict["name_list"]
        construct_print(f"Loading data from: {new_data_dict['root']}")
        assert len(self.total_image_paths) == len(
            self.total_mask_paths) == len(self.total_flow_paths)

        h, w = get_hw(in_size)
        self.transform = A.Compose(
            [
                A.Resize(height=h, width=w, interpolation=cv2.INTER_LINEAR),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ],
            additional_targets=dict(flow='image'),
        )
Example 12
    def resume_checkpoint(self, load_path, mode='all'):
        """
        从保存节点恢复模型

        Args:
            load_path (str): 模型存放路径
            mode (str): 选择哪种模型恢复模式:
                - 'all': 回复完整模型,包括训练中的的参数;
                - 'onlynet': 仅恢复模型权重参数
        """
        if os.path.exists(load_path) and os.path.isfile(load_path):
            construct_print(f"Loading checkpoint '{load_path}'")
            checkpoint = torch.load(load_path)
            if mode == 'all':
                if self.args["NET"] == checkpoint['arch']:
                    self.start_epoch = checkpoint['epoch']
                    self.net.load_state_dict(checkpoint['net_state'])
                    self.opti.load_state_dict(checkpoint['opti_state'])
                    construct_print(f"Loaded '{load_path}' "
                                    f"(epoch {checkpoint['epoch']})")
                else:
                    raise Exception(f"{load_path} does not match.")
            elif mode == 'onlynet':
                self.net.load_state_dict(checkpoint)
                construct_print(f"Loaded checkpoint '{load_path}' "
                                f"(only has the net's weight params)")
            else:
                raise NotImplementedError
        else:
            raise Exception(f"{load_path}路径不正常,请检查")
Example 13
def create_loader(data_path, mode, get_length=False, prefix=('.jpg', '.png')):
    length_of_dataset = 0

    if mode == 'train':
        construct_print(f"Training on: {data_path}")
        train_set = TrainImageFolder(data_path,
                                     in_size=arg_config["input_size"],
                                     prefix=prefix,
                                     use_bigt=arg_config['use_bigt'])
        loader = _make_loader(train_set, shuffle=True, drop_last=True)
        length_of_dataset = len(train_set)
    elif mode == 'test':
        if data_path is not None:
            construct_print(f"Testing on: {data_path}")
            test_set = TestImageFolder(data_path,
                                       in_size=arg_config["input_size"],
                                       prefix=prefix)
            loader = _make_loader(test_set, shuffle=False, drop_last=False)
            length_of_dataset = len(test_set)
        else:
            construct_print(f"No test...")
            loader = None
    else:
        raise NotImplementedError

    if get_length:
        return loader, length_of_dataset
    else:
        return loader
Example 14
    def test(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                mode="test",
                get_length=False,
                prefix=self.args["prefix"],
            )
            self.save_path = os.path.join(self.path["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(
                    f"{self.save_path} does not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self.__test_process(save_pre=self.save_pre)
            msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            construct_print(msg)
            make_log(self.path["te_log"], msg)

            total_results[data_name.upper()] = results

        self.net.train()
        return total_results
Example 15
    def test(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                training=False,
                prefix=self.arg_dict["prefix"],
                get_length=False,
            )
            self.save_path = os.path.join(self.path_dict["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(
                    f"{self.save_path} does not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self._test_process(save_pre=self.save_pre)
            msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            construct_print(msg)
            write_data_to_file(msg, self.path_dict["te_log"])

            total_results[data_name] = results

        self.net.train()

        if self.arg_dict["xlsx_name"]:
            # save result into xlsx file.
            self.xlsx_recorder.write_xlsx(self.exp_name, total_results)
Example 16
def main():
    check_mkdir(path_config["save"])
    check_mkdir(path_config["pth"])
    # shutil.copy(f"{user_config['proj_root']}/config.py", path_config["cfg_log"])
    # shutil.copy(f"{user_config['proj_root']}/train.py", path_config["trainer_log"])

    if user_config["is_distributed"]:
        construct_print("We will use the distributed training.")
        args.world_size = args.ngpus_per_node * args.nodes
        torch.multiprocessing.spawn(
            main_worker,
            nprocs=args.ngpus_per_node,
            args=(args.ngpus_per_node, args.world_size),
        )
    else:
        construct_print("We will not use the distributed training.")
        main_worker(
            local_rank=0,
            ngpus_per_node=1,
            world_size=1,
        )
    tb_recorder.close_tb()
Example 17
 def __init__(self, root, in_size, prefix, use_bigt=False):
     self.use_bigt = use_bigt
     if os.path.isdir(root):
         construct_print(f"{root} is an image folder, we will test on it.")
         self.imgs = _make_dataset(root)
     elif os.path.isfile(root):
         construct_print(
             f"{root} is a list of images, we will use these paths to read the "
             f"corresponding image")
         self.imgs = _make_train_dataset_from_list(root, prefix=prefix)
     else:
         raise NotImplementedError
     self.train_joint_transform = Compose(
         [JointResize(in_size),
          RandomHorizontallyFlip(),
          RandomRotate(10)])
     self.train_img_transform = transforms.Compose([
         transforms.ColorJitter(0.1, 0.1, 0.1),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406],
                              [0.229, 0.224, 0.225])  # operates on a Tensor
     ])
     self.train_mask_transform = transforms.ToTensor()
Example 18
def create_loader(data_path,
                  training,
                  size_list=None,
                  prefix=(".jpg", ".png"),
                  get_length=False):
    if training:
        construct_print(f"Training on: {data_path}")
        imageset = ImageFolder(
            data_path,
            in_size=arg_config["input_size"],
            prefix=prefix,
            use_bigt=arg_config["use_bigt"],
            training=True,
        )
        loader = _mask_loader(imageset,
                              shuffle=True,
                              drop_last=True,
                              size_list=size_list)
    else:
        construct_print(f"Testing on: {data_path}")
        imageset = ImageFolder(
            data_path,
            in_size=arg_config["input_size"],
            prefix=prefix,
            training=False,
        )
        loader = _mask_loader(imageset,
                              shuffle=False,
                              drop_last=False,
                              size_list=None)

    if get_length:
        length_of_dataset = len(imageset)
        return loader, length_of_dataset
    else:
        return loader
Example 19
def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    从保存节点恢复模型

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        exp_name (str): exp_name
        load_path (str): 模型存放路径
        mode (str): 选择哪种模型恢复模式:
            - 'all': 回复完整模型,包括训练中的的参数;
            - 'onlynet': 仅恢复模型权重参数
            
    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            if exp_name == checkpoint["arch"]:
                start_epoch = checkpoint["epoch"]
                model.load_state_dict(checkpoint["net_state"])
                optimizer.load_state_dict(checkpoint["opti_state"])
                scheduler.load_state_dict(checkpoint["sche_state"])
                construct_print(f"Loaded '{load_path}' "
                                f"(will train at epoch"
                                f" {checkpoint['epoch']})")
                return start_epoch
            else:
                raise Exception(f"{load_path} does not match.")
        elif mode == "onlynet":
            model.load_state_dict(checkpoint)
            construct_print(f"Loaded checkpoint '{load_path}' "
                            f"(only has the model's weight params)")
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")
Example 20
def test(model, mode="test", save_pre=True):
    model.eval()

    test_dataset_dict = user_config["rgb_data"]["te_data_list"]
    if mode == "val":
        test_dataset_dict = user_config["rgb_data"]["val_data_path"]

    total_results = {}
    for idx, (data_name, data_path) in enumerate(test_dataset_dict.items()):
        construct_print(f"Testing on the dataset: {data_name}, {data_path}")
        test_set = ImageFolder(root=data_path,
                               in_size=user_config["input_size"],
                               training=False)
        length = len(test_set)
        te_loader = create_loader(
            data_set=test_set,
            size_list=None,
            batch_size=batch_size_single_gpu,
            shuffle=False,
            num_workers=user_config["num_workers"],
            sampler=None,
            drop_last=False,
            pin_memory=True,
        )
        save_path = os.path.join(path_config["save"], data_name)
        if not os.path.exists(save_path):
            construct_print(f"{save_path} do not exist. Let's create it.")
            os.makedirs(save_path)
        results = _test_process(
            model=model,
            length=length,
            te_loader=te_loader,
            save_pre=save_pre,
            save_path=save_path,
        )
        msg = f"Results on the {mode}set({data_name}:'{data_path}'):\n{results}"
        write_data_to_file(msg, path_config["te_log"])
        construct_print(msg)

        total_results[data_name.upper()] = results
    return total_results
Example 21
def resume_checkpoint(exp_name, load_path, model, optimizer=None, scheduler=None, scaler=None, mode="all",
                      force_load=False):
    """
    从保存节点恢复模型

    Args:
        load_path (str): 模型存放路径
        model: your model
        optimizer: your optimizer
        mode (str): 选择哪种模型恢复模式:
            - 'all': 回复完整模型,包括训练中的的参数, will return start_epoch;
            - 'onlynet': 仅恢复模型权重参数
    """
    assert os.path.exists(load_path) and os.path.isfile(load_path), load_path

    construct_print(f"Loading checkpoint '{load_path}'")
    checkpoint = torch.load(load_path)

    if mode == "all":
        assert (optimizer is not None) and (scheduler is not None)
        if exp_name == checkpoint["arch"] or force_load:
            start_epoch = checkpoint["epoch"]
            model.load_state_dict(checkpoint["net_state"])
            optimizer.load_state_dict(checkpoint["opti_state"])
            scheduler.load_state_dict(checkpoint["sche_state"])
            if scaler and checkpoint.get('scaler', None) is not None:
                scaler.load_state_dict(checkpoint["scaler"])
            construct_print(
                f"Loaded '{load_path}' " f"(epoch {checkpoint['epoch']})")
        else:
            raise Exception(f"{load_path} does not match.")
        return start_epoch
    elif mode == "onlynet":
        model.load_state_dict(checkpoint)
        construct_print(
            f"Loaded checkpoint '{load_path}' " f"(only has the net's weight "
            f"params)")
    else:
        raise NotImplementedError
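
A hypothetical call site for this version (the argument values are illustrative): in "all" mode the function returns the epoch index to restart from, so the training loop can continue where it left off.

start_epoch = resume_checkpoint(
    exp_name=exp_name,
    load_path=path_dict["final_full_net"],
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
    mode="all",
)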
Example 22
def test(model, data_loader, save_path=""):
    """
    为了计算方便,训练过程中的验证与测试都直接计算指标J和F,不再先生成再输出,
    所以这里的指标仅作一个相对的参考,具体真实指标需要使用测试代码处理
    """
    model.eval()
    tqdm_iter = tqdm(enumerate(data_loader),
                     total=len(data_loader),
                     leave=False)

    if arg_config['use_tta']:
        construct_print("We will use Test Time Augmentation!")
        transforms = tta.Compose([  # 2 flips * 3 scales
            tta.HorizontalFlip(),
            tta.Scale(scales=[0.75, 1, 1.5],
                      interpolation='bilinear',
                      align_corners=False)
        ])
    else:
        transforms = None

    results = defaultdict(list)
    for test_batch_id, test_data in tqdm_iter:
        tqdm_iter.set_description(f"te=>{test_batch_id + 1}")

        with torch.no_grad():
            curr_jpegs = test_data["image"].to(DEVICES, non_blocking=True)
            curr_flows = test_data["flow"].to(DEVICES, non_blocking=True)
            preds_logits = tta_aug(model=model,
                                   transforms=transforms,
                                   data=dict(curr_jpeg=curr_jpegs,
                                             curr_flow=curr_flows))
            preds_prob = preds_logits.sigmoid().squeeze().cpu().detach()  # float32

        for i, pred_prob in enumerate(preds_prob.numpy()):
            curr_mask_path = test_data["mask_path"][i]
            video_name, mask_name = curr_mask_path.split(os.sep)[-2:]
            mask = read_binary_array(curr_mask_path, thr=0)
            mask_h, mask_w = mask.shape

            pred_prob = cv2.resize(pred_prob,
                                   dsize=(mask_w, mask_h),
                                   interpolation=cv2.INTER_LINEAR)
            pred_prob = clip_to_normalize(data_array=pred_prob,
                                          clip_range=arg_config["clip_range"])
            pred_seg = np.where(pred_prob > 0.5, 255, 0).astype(np.uint8)

            results[video_name].append(
                (jaccard.db_eval_iou(annotation=mask, segmentation=pred_seg),
                 f_boundary.db_eval_boundary(annotation=mask,
                                             segmentation=pred_seg)))

            if save_path:
                pred_video_path = os.path.join(save_path, video_name)
                if not os.path.exists(pred_video_path):
                    os.makedirs(pred_video_path)
                pred_frame_path = os.path.join(pred_video_path, mask_name)
                cv2.imwrite(pred_frame_path, pred_seg)

    j_f_collection = []
    for video_name, video_scores in results.items():
        j_f_for_video = np.mean(np.array(video_scores), axis=0).tolist()
        results[video_name] = j_f_for_video
        j_f_collection.append(j_f_for_video)
    results['average'] = np.mean(np.array(j_f_collection), axis=0).tolist()
    return pretty_print(results)
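
clip_to_normalize is another project helper not shown on this page; a minimal sketch consistent with its name and the clip_range argument (an assumption, not the project's code):

import numpy as np

def clip_to_normalize(data_array, clip_range=None):
    # Clip to the given (low, high) range if provided, then min-max rescale
    # the result back into [0, 1].
    if clip_range is not None:
        low, high = clip_range
        data_array = np.clip(data_array, low, high)
    lo, hi = data_array.min(), data_array.max()
    if hi <= lo:
        return np.zeros_like(data_array)
    return (data_array - lo) / (hi - lo)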
Example 23
 def wrapper(*args, **kwargs):
     start_time = datetime.now()
     construct_print(f"a new epoch start: {start_time}")
     func(*args, **kwargs)
     construct_print(
         f"the time of the epoch: {datetime.now() - start_time}")
Example 24
def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    从保存节点恢复模型

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        load_path (str): 模型存放路径
        mode (str): 选择哪种模型恢复模式:
            - 'all': 回复完整模型,包括训练中的的参数;
            - 'onlynet': 仅恢复模型权重参数

    Returns mode: 'all' start_epoch; 'onlynet' None
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            if exp_name and exp_name != checkpoint["arch"]:
                # If exp_name is given, it must match checkpoint["arch"]; otherwise no check is made.
                raise Exception(
                    f"We can not match {exp_name} with {load_path}.")

            start_epoch = checkpoint["epoch"]
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint["net_state"])
            else:
                model.load_state_dict(checkpoint["net_state"])
            optimizer.load_state_dict(checkpoint["opti_state"])
            scheduler.load_state_dict(checkpoint["sche_state"])
            if checkpoint.get("amp_state", None):
                if amp:
                    amp.load_state_dict(checkpoint["amp_state"])
                else:
                    construct_print("You are not using amp.")
            else:
                construct_print("The state_dict of amp is None.")
            construct_print(f"Loaded '{load_path}' "
                            f"(will train at epoch"
                            f" {checkpoint['epoch']})")
            return start_epoch
        elif mode == "onlynet":
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint)
            else:
                model.load_state_dict(checkpoint)
            construct_print(f"Loaded checkpoint '{load_path}' "
                            f"(only has the model's weight params)")
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")
Example 25
    def __init__(self, exp_name: str, arg_dict: dict, path_dict: dict):
        super(Solver, self).__init__()
        self.exp_name = exp_name
        self.arg_dict = arg_dict
        self.path_dict = path_dict

        self.dev = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.to_pil = transforms.ToPILImage()

        self.tr_data_path = self.arg_dict["rgb_data"]["tr_data_path"]
        self.te_data_list = self.arg_dict["rgb_data"]["te_data_list"]

        self.save_path = self.path_dict["save"]
        self.save_pre = self.arg_dict["save_pre"]

        if self.arg_dict["tb_update"] > 0:
            self.tb_recorder = TBRecorder(tb_path=self.path_dict["tb"])
        if self.arg_dict["xlsx_name"]:
            self.xlsx_recorder = XLSXRecoder(xlsx_path=self.path_dict["xlsx"])

        # attributes that depend on those defined above
        self.tr_loader = create_loader(
            data_path=self.tr_data_path,
            training=True,
            size_list=self.arg_dict["size_list"],
            prefix=self.arg_dict["prefix"],
            get_length=False,
        )
        self.end_epoch = self.arg_dict["epoch_num"]
        self.iter_num = self.end_epoch * len(self.tr_loader)

        if hasattr(network_lib, self.arg_dict["model"]):
            self.net = getattr(network_lib,
                               self.arg_dict["model"])().to(self.dev)
        else:
            raise AttributeError
        pprint(self.arg_dict)

        if self.arg_dict["resume_mode"] == "test":
            # resume model only to test model.
            # self.start_epoch is useless
            resume_checkpoint(
                model=self.net,
                load_path=self.path_dict["final_state_net"],
                mode="onlynet",
            )
            return

        self.loss_funcs = [
            torch.nn.BCEWithLogitsLoss(
                reduction=self.arg_dict["reduction"]).to(self.dev)
        ]
        if self.arg_dict["use_aux_loss"]:
            self.loss_funcs.append(CEL().to(self.dev))

        self.opti = make_optimizer(
            model=self.net,
            optimizer_type=self.arg_dict["optim"],
            optimizer_info=dict(
                lr=self.arg_dict["lr"],
                momentum=self.arg_dict["momentum"],
                weight_decay=self.arg_dict["weight_decay"],
                nesterov=self.arg_dict["nesterov"],
            ),
        )
        self.sche = make_scheduler(
            optimizer=self.opti,
            total_num=self.iter_num
            if self.arg_dict["sche_usebatch"] else self.end_epoch,
            scheduler_type=self.arg_dict["lr_type"],
            scheduler_info=dict(lr_decay=self.arg_dict["lr_decay"],
                                warmup_epoch=self.arg_dict["warmup_epoch"]),
        )

        # AMP
        if self.arg_dict["use_amp"]:
            construct_print("Now, we will use the amp to accelerate training!")
            from apex import amp

            self.amp = amp
            self.net, self.opti = self.amp.initialize(self.net,
                                                      self.opti,
                                                      opt_level="O1")
        else:
            self.amp = None

        if self.arg_dict["resume_mode"] == "train":
            # resume model to train the model
            self.start_epoch = resume_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                load_path=self.path_dict["final_full_net"],
                mode="all",
            )
        else:
            # only train a new model.
            self.start_epoch = 0
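
make_scheduler is defined elsewhere in the project; for context, a common "poly" policy that such a call might construct (a sketch under that assumption, ignoring warmup) is:

import torch

def make_poly_scheduler(optimizer, total_num, lr_decay=0.9):
    # lr(i) = base_lr * (1 - i / total_num) ** lr_decay, the standard "poly" decay.
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda i: (1 - i / total_num) ** lr_decay)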
Example 26
    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            for train_batch_id, train_data in enumerate(self.tr_loader):
                curr_iter = curr_epoch * len(self.tr_loader) + train_batch_id

                self.opti.zero_grad()
                train_inputs, train_masks, *train_other_data = train_data
                train_inputs = train_inputs.to(self.dev, non_blocking=True)
                train_masks = train_masks.to(self.dev, non_blocking=True)
                train_preds = self.net(train_inputs)

                train_loss, loss_item_list = self.total_loss(
                    train_preds, train_masks)
                train_loss.backward()
                self.opti.step()

                if self.args["sche_usebatch"]:
                    if self.args["lr_type"] == "poly":
                        self.sche.step(curr_iter + 1)
                    else:
                        raise NotImplementedError

                # use item() to fetch the value only when accumulating
                train_iter_loss = train_loss.item()
                train_batch_size = train_inputs.size(0)
                train_loss_record.update(train_iter_loss, train_batch_size)

                # log to tensorboard
                if (self.args["tb_update"] > 0
                        and (curr_iter + 1) % self.args["tb_update"] == 0):
                    self.tb.add_scalar("data/trloss_avg",
                                       train_loss_record.avg, curr_iter)
                    self.tb.add_scalar("data/trloss_iter", train_iter_loss,
                                       curr_iter)
                    self.tb.add_scalar("data/trlr",
                                       self.opti.param_groups[0]["lr"],
                                       curr_iter)
                    tr_tb_mask = make_grid(train_masks,
                                           nrow=train_batch_size,
                                           padding=5)
                    self.tb.add_image("trmasks", tr_tb_mask, curr_iter)
                    tr_tb_out_1 = make_grid(train_preds,
                                            nrow=train_batch_size,
                                            padding=5)
                    self.tb.add_image("trsodout", tr_tb_out_1, curr_iter)

                # log the data of each iteration
                if (self.args["print_freq"] > 0
                        and (curr_iter + 1) % self.args["print_freq"] == 0):
                    log = (
                        f"[I:{curr_iter}/{self.iter_num}][E:{curr_epoch}:{self.end_epoch}]>"
                        f"[{self.model_name}]"
                        f"[Lr:{self.opti.param_groups[0]['lr']:.7f}]"
                        f"[Avg:{train_loss_record.avg:.5f}|Cur:{train_iter_loss:.5f}|"
                        f"{loss_item_list}]")
                    print(log)
                    make_log(self.path["tr_log"], log)

            # adjust the learning rate once per epoch
            if not self.args["sche_usebatch"]:
                if self.args["lr_type"] == "poly":
                    self.sche.step(curr_epoch + 1)
                else:
                    raise NotImplementedError

            # save after every epoch; the saved state targets epoch curr_epoch + 1
            self.save_checkpoint(
                curr_epoch + 1,
                full_net_path=self.path['final_full_net'],
                state_net_path=self.path['final_state_net'])  # save the parameters

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader, self.te_length = create_loader(data_path=data_path,
                                                           mode='test',
                                                           get_length=True)
            self.save_path = os.path.join(self.path["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(
                    f"{self.save_path} does not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self.test(save_pre=self.save_pre)
            msg = (
                f"Results on the testset({data_name}:'{data_path}'): {results}"
            )
            construct_print(msg)
            make_log(self.path["te_log"], msg)

            total_results[data_name.upper()] = results
        # save result into xlsx file.
        write_xlsx(self.model_name, total_results)
Example 27
 def freeze_bn(self):
     construct_print("We will freeze all BN layers.")
     for m in self.modules():
         if isinstance(m, nn.BatchNorm2d):
             m.eval()
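
One caveat worth noting: model.train() switches every submodule, including BatchNorm, back to training mode, so freeze_bn has to be re-applied after each such call. A hypothetical usage:

model.train()      # re-enables BN running-stat updates everywhere
model.freeze_bn()  # must follow train() to keep the BN layers in eval mode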
Example 28
import shutil
from datetime import datetime

from config import arg_config, proj_root
from utils.misc import construct_exp_name, construct_path, construct_print, pre_mkdir, set_seed
from utils.solver import Solver

if __name__ == '__main__':
    construct_print(f"{datetime.now()}: Initializing...")
    construct_print(f"Project Root: {proj_root}")
    init_start = datetime.now()

    exp_name = construct_exp_name(arg_config)
    path_config = construct_path(
        proj_root=proj_root, exp_name=exp_name, xlsx_name=arg_config["xlsx_name"],
    )
    pre_mkdir(path_config)
    set_seed(seed=0, use_cudnn_benchmark=arg_config["size_list"] is not None)

    solver = Solver(exp_name, arg_config, path_config)
    construct_print(f"Total initialization time:{datetime.now() - init_start}")

    shutil.copy(f"{proj_root}/config.py", path_config["cfg_log"])
    shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"])

    construct_print(f"{datetime.now()}: Start...")
    if arg_config["resume_mode"] == "test" or arg_config["resume_mode"] == "measure":
        solver.test()
    else:
        solver.train()
    construct_print(f"{datetime.now()}: End...")
Example 29
        self.train_mask_transform = transforms.ToTensor()

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index]

        img = Image.open(img_path)
        mask = Image.open(mask_path)
        if len(img.split()) != 3:
            img = img.convert("RGB")
        if len(mask.split()) == 3:
            mask = mask.convert("L")

        img, mask = self.train_joint_transform(img, mask)
        mask = self.train_mask_transform(mask)
        img = self.train_img_transform(img)

        if self.use_bigt:
            mask = mask.ge(0.5).float()  # binarize

        img_name = (img_path.split(os.sep)[-1]).split(".")[0]

        return img, mask, img_name

    def __len__(self):
        return len(self.imgs)


if __name__ == "__main__":
    img_list = _make_train_dataset_from_list()
    construct_print(len(img_list))
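
The use_bigt branch in __getitem__ binarizes the mask tensor; a quick standalone illustration of ge(0.5).float():

import torch

mask = torch.tensor([0.2, 0.5, 0.9])
print(mask.ge(0.5).float())  # tensor([0., 1., 1.]) -- values >= 0.5 become 1.0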
Example 30
        # adjust the learning rate once per epoch
        if not scheduler_usebatch:
            scheduler.step()

        # save after every epoch; the saved state targets epoch curr_epoch + 1
        save_checkpoint(exp_name=exp_name,
                        model=model,
                        current_epoch=curr_epoch + 1,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        scaler=scaler,
                        full_net_path=path_dict["final_full_net"],
                        state_net_path=path_dict["final_state_net"])


construct_print(f"{datetime.now()}: Initializing...")
construct_print(f"Project Root: {proj_root}")
pprint(arg_config)

# construct exp_name
exp_name = construct_exp_name(arg_dict=arg_config,
                              extra_dicts=[optimizer_config, schedule_config])
path_dict = construct_path(proj_root=proj_root, exp_name=exp_name)

initialize_seed_cudnn(
    seed=0,
    use_cudnn_benchmark=False if arg_config["use_mstrain"] else arg_config["use_cudnn_benchmark"],
)
pre_mkdir(path_config=path_dict)
pre_copy(
    main_file_path=__file__,
    all_config=dict(args=arg_config,
                    path=path_dict,