def __init__(self):
    """Build the feature extractor, the trainable head, and all training state.

    Reads its configuration from names defined outside this method
    (``args``, ``pretrain_path``, ``train_data_loader``, ``device``,
    ``log_path``) rather than from parameters.

    Steps, in order:
      1. Construct the ``resnet101`` extractor, load the weights returned by
         ``get_pretrain_weight`` and switch it to eval mode.
      2. Construct ``BaseNet`` as ``self.G`` with an Adam optimizer over its
         parameters.
      3. Create the BCE / L1 / MSE loss objects.
      4. Store the data loader, device, epoch count and checkpoint setting.
      5. Move both networks to CUDA when it is available.
      6. Create the log directory, derive the checkpoint file path and
         set up the CSV logger.
    """
    frame_count = args.train_frames
    self.train_frames = frame_count
    print("===> Building model")

    # --- pretrained feature extractor -----------------------------------
    extractor_config = dict(
        num_classes=400,
        shortcut_type='B',
        cardinality=32,
        sample_size=args.crop_size,
        sample_duration=frame_count,
    )
    backbone = resnet101(**extractor_config)
    self.extractor = backbone

    # Load the pretrained weights. The extractor must always be put in
    # eval mode after loading.
    pretrained_state = get_pretrain_weight(pretrain_path, backbone)
    backbone.load_state_dict(pretrained_state)
    backbone.eval()

    # --- trainable network and its optimizer ----------------------------
    # NOTE(review): the two BaseNet arguments are presumably the flattened
    # feature size and the number of outputs -- confirm against BaseNet.
    generator = BaseNet(16384 * 4 * 4 * 4, 10)
    self.G = generator
    self.optimizerG = optim.Adam(
        generator.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-8,
    )

    # --- loss functions -------------------------------------------------
    self.BCE_loss, self.L1_loss, self.L2_loss = (
        nn.BCELoss(),
        nn.L1Loss(),
        nn.MSELoss(),
    )

    # --- data -----------------------------------------------------------
    # Only the training loader is attached; a test loader is not wired up.
    self.train_data = train_data_loader

    # --- training bookkeeping -------------------------------------------
    self.device = device
    self.epochs = args.epochs
    self.avg_G_loss_arr = []
    self.checkpoint = args.checkpoint

    # --- device placement -----------------------------------------------
    if torch.cuda.is_available():
        for network in (backbone, generator):
            network.cuda()

    # --- checkpoint / logging locations ---------------------------------
    self.save_mname = args.save_model_name

    make_dirs(log_path)
    self.log_dir = log_path + f'/{self.save_mname}'
    # make_dirs' return value is used as the directory prefix of the
    # checkpoint file (presumably it returns the path it created -- confirm).
    checkpoint_root = make_dirs(self.log_dir)
    self.save_check_dir = checkpoint_root + '/' + 'checkpoint.pkl'

    # CSV log of per-epoch metrics, stored inside the log directory.
    csv_path = self.log_dir + f"/{self.save_mname}_log.csv"
    self.CSVlogger = LogCSV(
        log_dir=csv_path,
        header=['epoch', 'avg_G_Loss', 'accuracy'],
    )
def checkpoint_set(self):
    """Configure the checkpoint and logging locations for this run.

    Reads ``args.save_model_name`` and ``log_path`` from the enclosing
    scope and sets the following attributes:

    :param self.save_mname: model name used to label this run's outputs
    :param self.log_dir: ``log_path`` joined with the model name
    :param self.save_check_dir: value returned by ``make_dirs(self.log_dir)``
    :param self.graph_dir: same location as ``self.log_dir``
    :param self.CSVlogger: ``LogCSV`` instance writing
        ``<log_dir>/<save_mname>_log.csv`` with the columns
        ``epoch``, ``avg_G_Loss`` and ``accuracy``

    NOTE(review): ``__init__`` stores a path ending in ``checkpoint.pkl`` in
    ``save_check_dir``, while this method stores the directory itself --
    confirm which form callers expect.
    """
    self.save_mname = args.save_model_name

    # Make sure the root log folder exists before creating the run folder.
    make_dirs(log_path)
    run_dir = log_path + f'/{self.save_mname}'
    self.log_dir = run_dir

    # make_dirs presumably returns the path it created -- confirm.
    self.save_check_dir = make_dirs(run_dir)
    self.graph_dir = run_dir

    # CSV log of per-epoch metrics, stored inside the run folder.
    csv_path = run_dir + f"/{self.save_mname}_log.csv"
    self.CSVlogger = LogCSV(
        log_dir=csv_path,
        header=['epoch', 'avg_G_Loss', 'accuracy'],
    )