Пример #1
0
from datetime import datetime

from config import arg_config, proj_root
from utils.misc import construct_exp_name, construct_path, construct_print, pre_mkdir, set_seed
from utils.solver import Solver

if __name__ == '__main__':
    # `shutil` is needed for the source snapshots below but is not imported at
    # the top of this script; import it here so the run does not die with a
    # NameError right after the (expensive) solver initialization.
    import shutil

    construct_print(f"{datetime.now()}: Initializing...")
    construct_print(f"Project Root: {proj_root}")
    init_start = datetime.now()

    # Derive the experiment name and its output paths, then create directories.
    exp_name = construct_exp_name(arg_config)
    path_config = construct_path(
        proj_root=proj_root, exp_name=exp_name, xlsx_name=arg_config["xlsx_name"],
    )
    pre_mkdir(path_config)
    # use_cudnn_benchmark is True only when a size list is configured
    # (presumably fixed input sizes — confirm against config.py).
    # `is not None` is used instead of `!= None` because `!=` may be
    # overloaded (e.g. element-wise on array-like values).
    set_seed(seed=0, use_cudnn_benchmark=arg_config["size_list"] is not None)

    solver = Solver(exp_name, arg_config, path_config)
    construct_print(f"Total initialization time:{datetime.now() - init_start}")

    # Copy the config and solver sources into this run's log locations.
    shutil.copy(f"{proj_root}/config.py", path_config["cfg_log"])
    shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"])

    construct_print(f"{datetime.now()}: Start...")
    # "test" and "measure" both run solver.test(); every other mode trains.
    if arg_config["resume_mode"] in ("test", "measure"):
        solver.test()
    else:
        solver.train()
    construct_print(f"{datetime.now()}: End...")
Пример #2
0
import shutil
from datetime import datetime

from utils.config import arg_config, path_config, proj_root
from utils.misc import construct_print, pre_mkdir, set_seed
from utils.solver import Solver

# Training entry point: print start-up info, set the seed, create output
# directories, build the solver, copy the config/solver sources into the log
# locations from path_config, then run training.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module starts a full training run as a side effect.
construct_print(f"{datetime.now()}: Initializing...")
construct_print(f"Project Root: {proj_root}")
init_start = datetime.now()  # start time for the initialization-time report below
set_seed(0)  # fixed seed 0; what it seeds is defined in utils.misc
pre_mkdir()  # presumably creates the directories from utils.config — verify in utils.misc
solver = Solver(arg_config, path_config)
construct_print(f"Total initialization time:{datetime.now() - init_start}")

# Snapshot the config and solver source files into the run's log paths,
# presumably so the exact settings/code of this run can be inspected later.
shutil.copy(f"{proj_root}/utils/config.py", path_config["cfg_log"])
shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"])

construct_print(f"{datetime.now()}: Start training...")
solver.train()
construct_print(f"{datetime.now()}: End training...")
Пример #3
0
                        full_net_path=path_dict["final_full_net"],
                        state_net_path=path_dict["final_state_net"])


# Script start-up: print the configuration, derive the experiment name and
# output paths, seed/configure cuDNN, create directories, snapshot configs,
# and optionally create a TBRecorder. The module-level names defined here
# (exp_name, path_dict, tb_recorder, tr_data_info) are used by later code.
construct_print(f"{datetime.now()}: Initializing...")
construct_print(f"Project Root: {proj_root}")
pprint(arg_config)

# construct exp_name from the main args plus the optimizer/schedule configs
exp_name = construct_exp_name(arg_dict=arg_config,
                              extra_dicts=[optimizer_config, schedule_config])
path_dict = construct_path(proj_root=proj_root, exp_name=exp_name)

# cuDNN benchmarking is forced off when `use_mstrain` is set (presumably
# multi-scale training, where input sizes vary); otherwise the config flag
# `use_cudnn_benchmark` decides.
initialize_seed_cudnn(seed=0, use_cudnn_benchmark=False \
    if arg_config["use_mstrain"] else arg_config["use_cudnn_benchmark"])
pre_mkdir(path_config=path_dict)
# Presumably copies this script and dumps all config dicts into the
# experiment directory — confirm in pre_copy's implementation.
pre_copy(
    main_file_path=__file__,
    all_config=dict(args=arg_config,
                    path=path_dict,
                    opti=optimizer_config,
                    sche=schedule_config),
    proj_root=proj_root,
)

# The recorder (presumably TensorBoard) is only created when tb_update > 0.
# NOTE(review): `tb_recorder` is left undefined otherwise; any later use must
# be guarded by the same condition — confirm.
if arg_config["tb_update"] > 0:
    tb_recorder = TBRecorder(path_dict["tb"])

# Training-set entry of the data config.
tr_data_info = arg_config["data"]["tr"]
tr_loader_info = create_loader(
    data_path=arg_config["data"]["tr"],
Пример #4
0
    def __init__(self, args):
        """Set up paths, data, network, losses and optimizer from ``args``.

        Resolves the train/test data locations for ``args["data_mode"]``,
        builds the model name and output paths, creates the output
        directories, and copies ``config.py``/``train.py`` into the log paths.
        It then creates the training loader, the network, the loss functions
        and the optimizer, and optionally resumes from a saved checkpoint.

        Args:
            args: configuration mapping. Keys read here include
                ``data_mode``, ``model``, ``suffix``, ``rgbd_data`` or
                ``rgb_data``, ``save_pre``, ``reduction``, ``use_aux_loss``,
                ``epoch_num`` and ``resume``.

        Raises:
            NotImplementedError: if ``args["data_mode"]`` is neither
                ``"RGBD"`` nor ``"RGB"``.
        """
        super().__init__()
        self.args = args
        self.dev = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.to_pil = transforms.ToPILImage()
        pprint(self.args)

        self.data_mode = self.args["data_mode"]
        # Resolve the data locations before touching the filesystem so an
        # unsupported data_mode fails before directories are created or files
        # are copied.
        if self.data_mode == "RGBD":
            data_cfg = self.args["rgbd_data"]
        elif self.data_mode == "RGB":
            data_cfg = self.args["rgb_data"]
        else:
            raise NotImplementedError(
                f"Unsupported data_mode: {self.data_mode!r} "
                f"(expected 'RGBD' or 'RGB')")
        self.tr_data_path = data_cfg["tr_data_path"]
        self.te_data_list = data_cfg["te_data_list"]

        if self.args["suffix"]:
            self.model_name = self.args["model"] + "_" + self.args["suffix"]
        else:
            self.model_name = self.args["model"]
        self.path = construct_path_dict(proj_root=proj_root,
                                        exp_name=self.model_name)

        pre_mkdir(path_config=self.path)
        # Copy the config and training script into this run's log paths.
        shutil.copy(f"{proj_root}/config.py", self.path["cfg_log"])
        shutil.copy(f"{proj_root}/train.py", self.path["trainer_log"])

        self.save_path = self.path["save"]
        self.save_pre = self.args["save_pre"]

        self.tr_loader = create_loader(
            data_path=self.tr_data_path,
            mode="train",
            get_length=False,
            data_mode=self.data_mode,
        )

        # Look up the network constructor by its name in the `network` module.
        self.net = getattr(network,
                           self.args["model"])(pretrained=True).to(self.dev)

        # Loss functions: BCE always; HEL appended when use_aux_loss is set.
        self.loss_funcs = [
            BCELoss(reduction=self.args["reduction"]).to(self.dev)
        ]
        if self.args["use_aux_loss"]:
            self.loss_funcs.append(HEL().to(self.dev))

        # Set up the optimizer.
        self.opti = self.make_optim()

        # Training-related state.
        self.end_epoch = self.args["epoch_num"]
        if self.args["resume"]:
            try:
                # NOTE(review): start_epoch is not assigned on this path here;
                # presumably resume_checkpoint(mode="all") restores it — confirm.
                self.resume_checkpoint(load_path=self.path["final_full_net"],
                                       mode="all")
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt and
                # SystemExit propagate instead of triggering the fallback.
                # Fallback: load the weights-only checkpoint and set
                # start_epoch to end_epoch (no epochs left to run).
                print(
                    f"{self.path['final_full_net']} does not exist and we will load {self.path['final_state_net']}"
                )
                self.resume_checkpoint(load_path=self.path["final_state_net"],
                                       mode="onlynet")
                self.start_epoch = self.end_epoch
        else:
            self.start_epoch = 0
        # Total iteration count over all epochs (presumably len(tr_loader) is
        # the number of batches per epoch).
        self.iter_num = self.end_epoch * len(self.tr_loader)