示例#1
0
    def __init__(self):
        """Build the frozen feature extractor, the trainable head, and training state.

        Configuration is read from the module-level ``args`` namespace and the
        module-level globals ``pretrain_path``, ``train_data_loader``,
        ``device`` and ``log_path``.

        Attributes set here include:
            extractor: ``resnet101`` backbone loaded with weights from
                ``pretrain_path``; kept in eval mode with gradients disabled
                because it is never passed to an optimizer.
            G: ``BaseNet`` head; the only module optimized by ``optimizerG``.
            optimizerG: Adam over ``G``'s parameters.
            BCE_loss, L1_loss, L2_loss: loss criteria.
            save_check_dir: ``make_dirs(log_dir) + '/checkpoint.pkl'``.
            CSVlogger: ``LogCSV`` writing ``<log_dir>/<save_mname>_log.csv``.
        """
        self.train_frames = args.train_frames
        print("===> Building model")

        # ---- Model setting ----
        # Pretrained backbone used as a fixed feature extractor.
        self.extractor = resnet101(num_classes=400,
                                   shortcut_type='B',
                                   cardinality=32,
                                   sample_size=args.crop_size,
                                   sample_duration=args.train_frames)

        # Load the pretrained weights.
        # The extractor must always be in eval mode.
        weight = get_pretrain_weight(pretrain_path, self.extractor)
        self.extractor.load_state_dict(weight)
        self.extractor.eval()
        # eval() only switches BatchNorm/Dropout behaviour; it does NOT stop
        # autograd. The extractor is never given to an optimizer, so turn off
        # its gradients. This avoids building a backward graph through the
        # whole backbone and accumulating unused .grad tensors.
        for param in self.extractor.parameters():
            param.requires_grad = False

        # Trainable head.
        # NOTE(review): the input size 16384 * 4 * 4 * 4 must match the
        # flattened extractor feature size — confirm against the forward pass.
        # 10 is presumably the number of output classes.
        self.G = BaseNet(16384 * 4 * 4 * 4, 10)

        # ---- Optimizer (only the head G is trained) ----
        self.optimizerG = optim.Adam(self.G.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-8)
        # ---- Loss criteria ----
        self.BCE_loss = nn.BCELoss()
        self.L1_loss = nn.L1Loss()
        self.L2_loss = nn.MSELoss()
        # ---- Data ----
        self.train_data = train_data_loader
        # ---- Training state ----
        self.device = device
        self.epochs = args.epochs
        self.avg_G_loss_arr = []
        self.checkpoint = args.checkpoint

        # Move both models to the GPU when one is available.
        # NOTE(review): this uses .cuda() rather than self.device. Confirm the
        # two always agree, otherwise inputs and weights may end up on
        # different devices.
        if torch.cuda.is_available():
            self.extractor.cuda()
            self.G.cuda()

        # ---- Checkpoint / logging paths ----
        self.save_mname = args.save_model_name
        make_dirs(log_path)
        self.log_dir = log_path + f'/{self.save_mname}'
        # make_dirs(...) is concatenated with a str here, so it presumably
        # returns the directory path it was given — confirm.
        self.save_check_dir = make_dirs(self.log_dir) + '/' + 'checkpoint.pkl'

        # CSV logging system
        self.CSVlogger = LogCSV(log_dir=self.log_dir +
                                f"/{self.save_mname}_log.csv",
                                header=['epoch', 'avg_G_Loss', 'accuracy'])
示例#2
0
    def checkpoint_set(self):
        """Configure checkpoint/log locations and the CSV loss logger.

        Reads ``args.save_model_name`` and the module-level ``log_path``, then
        sets these attributes:
            save_mname: the model name (``args.save_model_name``).
            log_dir: ``log_path + '/' + save_mname``.
            save_check_dir: the value returned by ``make_dirs(log_dir)``
                (no checkpoint filename is appended here).
            graph_dir: same string as ``log_dir``.
            CSVlogger: ``LogCSV`` writing ``<log_dir>/<save_mname>_log.csv``
                with columns epoch / avg_G_Loss / accuracy.
        """
        model_name = args.save_model_name
        self.save_mname = model_name

        # The root folder is handled first, then the per-model sub-folder
        # (presumably make_dirs creates the directory if missing — confirm).
        make_dirs(log_path)
        run_dir = log_path + f'/{model_name}'
        self.log_dir = run_dir
        self.save_check_dir = make_dirs(run_dir)
        self.graph_dir = run_dir

        csv_file = run_dir + f"/{model_name}_log.csv"
        self.CSVlogger = LogCSV(log_dir=csv_file,
                                header=['epoch', 'avg_G_Loss', 'accuracy'])