def train(self, epochs, start_epoch=0):
    """Run epochs ``start_epoch .. epochs-1``: train, checkpoint, evaluate, report.

    For every epoch the learning rate is set via ``self._lr(epoch)``, one pass of
    ``self._train_epoch()`` is run, a checkpoint is written with
    ``self._save_checkpoint``, and ``self.test()`` is evaluated. A one-line summary
    (train acc / acc_k / loss and test acc / acc_k / loss) is printed every epoch,
    and appended to ``param.log_path`` only when the test accuracy beats the best
    value seen so far in this call.

    The best accuracy is tracked only within this call; it restarts at 0.0 when
    resuming with a non-zero ``start_epoch``.
    """
    best_test_acc = 0.0
    for epoch in range(start_epoch, epochs):
        Tools.print()
        Tools.print("Start Epoch {}".format(epoch))

        # Apply this epoch's learning rate, then report what the optimizer now uses.
        self._lr(epoch)
        current_lr = self.optimizer.param_groups[0]['lr']
        Tools.print('Epoch:{:02d},lr={:.4f}'.format(epoch, current_lr))

        train_loss, train_acc, train_acc_k = self._train_epoch()
        self._save_checkpoint(self.model, self.root_ckpt_dir, epoch)
        test_loss, test_acc, test_acc_k = self.test()

        summary = 'Epoch:{:02d}, Train:{:.4f}-{:.4f}/{:.4f} Test:{:.4f}-{:.4f}/{:.4f}'.format(
            epoch,
            train_acc, train_acc_k, train_loss,
            test_acc, test_acc_k, test_loss)
        Tools.print(summary)

        # NOTE(review): only epochs that improve test accuracy reach the log file;
        # confirm this is the intended logging policy.
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            Tools.write_to_txt(param.log_path, summary + "\n")
def __init__(self):
    """Build datasets, loaders, model, optimizer and loss from the module-level ``param``.

    The optimizer is SGD (momentum 0.9) when ``param.is_sgd`` is truthy, otherwise
    Adam; both start at learning rate ``param.lr[0][1]`` and use
    ``param.weight_decay``. The model and the cross-entropy loss are moved to
    ``param.device``. The model's parameter count is printed and appended to
    ``param.log_path``.
    """
    self.device = param.device
    self.root_ckpt_dir = param.root_ckpt_dir
    # Learning-rate schedule; presumably (epoch, lr) pairs consumed by self._lr — confirm.
    self.lr_s = param.lr

    # The two splits share every dataset option except is_train.
    shared_opts = dict(data_root_path=param.data_root, image_size=param.image_size,
                       sp_size=param.sp_size, padding=param.padding)
    self.train_dataset = MyDataset(is_train=True, **shared_opts)
    self.test_dataset = MyDataset(is_train=False, **shared_opts)

    def make_loader(dataset, shuffle):
        # Each dataset provides its own collate_fn for batching.
        return DataLoader(dataset, batch_size=param.batch_size, shuffle=shuffle,
                          num_workers=param.num_workers, collate_fn=dataset.collate_fn)

    self.train_loader = make_loader(self.train_dataset, shuffle=True)
    self.test_loader = make_loader(self.test_dataset, shuffle=False)

    self.model = MyGCNNet().to(self.device)
    start_lr = self.lr_s[0][1]
    if param.is_sgd:
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=start_lr,
                                         momentum=0.9, weight_decay=param.weight_decay)
    else:
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=start_lr,
                                          weight_decay=param.weight_decay)
    self.loss_class = nn.CrossEntropyLoss().to(self.device)

    param_summary = "Total param: {}".format(self._view_model_param(self.model))
    Tools.print(param_summary)
    Tools.write_to_txt(param.log_path, param_summary + "\n")
"""Randomly split the clips under ``data_dir`` into ``train.list`` / ``test.list``.

Every sub-directory of ``data_dir`` (sorted by name) is treated as one class. Its
position in that sorted order is written after each entry's path, presumably as
the class label. Each entry inside a class directory goes to the test list with
probability 1/split_num, otherwise to the train list.
"""
import os
import random

from alisuretool.Tools import Tools

split_num = 4
# Fixed seed so re-running the script regenerates exactly the same split.
# A local RNG is used so the global `random` state is left untouched.
seed = 0
data_dir = "/home/ubuntu/data1.5TB/C3D/UCF-101"

# NOTE(review): this is a random per-entry split; if the dataset's official split
# files are available, prefer them to avoid related clips landing in both
# train and test — confirm this random split is intended.
rng = random.Random(seed)

train_list, test_list = [], []
video_dir = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]
# Skip stray files (e.g. hidden metadata files) so os.listdir below does not fail
# on a non-directory entry.
video_dir = [path for path in video_dir if os.path.isdir(path)]
for video_index, video_one in enumerate(video_dir):
    video_one_dir = [os.path.join(video_one, name) for name in sorted(os.listdir(video_one))]
    for video_one_dir_one in video_one_dir:
        line = "{} {}\n".format(video_one_dir_one, video_index)
        # randint(0, split_num - 1) == 0 with probability 1/split_num -> test set.
        if rng.randint(0, split_num - 1) > 0:
            train_list.append(line)
        else:
            test_list.append(line)

Tools.write_to_txt("train.list", train_list, reset=True)
Tools.write_to_txt("test.list", test_list, reset=True)
def print(self):
    """Show ``self.name`` and the instance attribute dict on stdout; log the name.

    Only ``self.name`` (followed by a newline) is appended to ``self.log_path``;
    the attribute dict is printed but not written to the log.
    """
    # The bare `print` below is the builtin: a method name is a class attribute
    # and is not visible as a plain name inside the method body.
    for shown in (self.name, self.__dict__):
        print(shown)
    Tools.write_to_txt(self.log_path, self.name + "\n")
mask_path = "/mnt/4T/Data/SOD/DUTS/DUTS-TE/DUTS-TE-Mask" result_path = "/mnt/4T/ALISURE/GCN/PyTorchGCN_Result/PYG_ChangeGCN_GCNAtt_NoAddGCN_NoAttRes/DUTS-TE/SOD" mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask" result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes/DUTS-TE/SOD" Max F-measre: 0.894474 Precision: 0.920642 Recall: 0.817062 MAE: 0.0368984 mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask" result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes_NewPool/DUTS-TE/SOD" Max F-measre: 0.894308 Precision: 0.916229 Recall: 0.828253 MAE: 0.0370773 mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask" result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes_Sigmoid/DUTS-TE/SOD" """ if __name__ == '__main__': mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask" result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes_Sigmoid/DUTS-TE/SOD" _result_files = get_file(mask_path, result_path) _txt = "\n".join(_result_files) Tools.write_to_txt("salmetric.txt", _txt, reset=True) pass