def train(self, epochs, start_epoch=0):
        """Run the training loop from ``start_epoch`` up to ``epochs``.

        Each epoch: step the lr schedule, train one epoch, save a checkpoint,
        evaluate on the test set, and print a one-line summary. Whenever the
        test accuracy reaches a new best, the summary line is also appended
        to the log file at ``param.log_path``.
        """
        best_test_acc = 0.0
        for epoch in range(start_epoch, epochs):
            Tools.print()
            Tools.print("Start Epoch {}".format(epoch))

            # Apply the lr schedule for this epoch, then report the active lr.
            self._lr(epoch)
            Tools.print('Epoch:{:02d},lr={:.4f}'.format(
                epoch, self.optimizer.param_groups[0]['lr']))

            train_loss, train_acc, train_acc_k = self._train_epoch()
            self._save_checkpoint(self.model, self.root_ckpt_dir, epoch)
            test_loss, test_acc, test_acc_k = self.test()

            result_str = 'Epoch:{:02d}, Train:{:.4f}-{:.4f}/{:.4f} Test:{:.4f}-{:.4f}/{:.4f}'.format(
                epoch, train_acc, train_acc_k, train_loss,
                test_acc, test_acc_k, test_loss)
            Tools.print(result_str)

            # Persist only new-best results to the log file.
            if test_acc > best_test_acc:
                best_test_acc = test_acc
                Tools.write_to_txt(param.log_path, result_str + "\n")
        pass
    def __init__(self):
        """Build datasets, loaders, model, optimizer and loss from the global `param` config."""
        self.device = param.device
        self.root_ckpt_dir = param.root_ckpt_dir
        self.lr_s = param.lr  # (epoch, lr) schedule pairs

        # Train/test datasets share every option except the is_train flag.
        dataset_kwargs = dict(data_root_path=param.data_root,
                              image_size=param.image_size,
                              sp_size=param.sp_size,
                              padding=param.padding)
        self.train_dataset = MyDataset(is_train=True, **dataset_kwargs)
        self.test_dataset = MyDataset(is_train=False, **dataset_kwargs)

        # Only the training loader shuffles; both use the dataset's collate_fn.
        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=param.batch_size,
                                       shuffle=True,
                                       num_workers=param.num_workers,
                                       collate_fn=self.train_dataset.collate_fn)
        self.test_loader = DataLoader(self.test_dataset,
                                      batch_size=param.batch_size,
                                      shuffle=False,
                                      num_workers=param.num_workers,
                                      collate_fn=self.test_dataset.collate_fn)

        self.model = MyGCNNet().to(self.device)

        # Initial lr is taken from the first entry of the schedule.
        if param.is_sgd:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.lr_s[0][1],
                                             momentum=0.9,
                                             weight_decay=param.weight_decay)
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=self.lr_s[0][1],
                                              weight_decay=param.weight_decay)

        self.loss_class = nn.CrossEntropyLoss().to(self.device)

        info = "Total param: {}".format(self._view_model_param(self.model))
        Tools.print(info)
        Tools.write_to_txt(param.log_path, info + "\n")
        pass
Example #3
import os
import random
from alisuretool.Tools import Tools

# Randomly assign each video folder to train (~3/4) or test (~1/4).
split_num = 4
data_dir = "/home/ubuntu/data1.5TB/C3D/UCF-101"

train_list, test_list = [], []
class_dirs = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]
for class_index, class_dir in enumerate(class_dirs):
    for video_name in sorted(os.listdir(class_dir)):
        line = "{} {}\n".format(os.path.join(class_dir, video_name), class_index)
        # randint in [0, split_num-1]: 0 -> test, anything else -> train.
        target = train_list if random.randint(0, split_num - 1) > 0 else test_list
        target.append(line)

Tools.write_to_txt("train.list", train_list, reset=True)
Tools.write_to_txt("test.list", test_list, reset=True)
 def print(self):
     """Show the run name and the full attribute dict; append the name to the log file."""
     print(self.name)
     print(vars(self))  # vars(self) is self.__dict__
     Tools.write_to_txt(self.log_path, "{}\n".format(self.name))
Example #5
mask_path = "/mnt/4T/Data/SOD/DUTS/DUTS-TE/DUTS-TE-Mask"
result_path = "/mnt/4T/ALISURE/GCN/PyTorchGCN_Result/PYG_ChangeGCN_GCNAtt_NoAddGCN_NoAttRes/DUTS-TE/SOD"

mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask"
result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes/DUTS-TE/SOD"

Max F-measure: 0.894474
Precision:    0.920642
Recall:       0.817062
MAE:          0.0368984
mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask"
result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes_NewPool/DUTS-TE/SOD"


Max F-measure: 0.894308
Precision:    0.916229
Recall:       0.828253
MAE:          0.0370773
mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask"
result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes_Sigmoid/DUTS-TE/SOD"
"""

if __name__ == '__main__':
    mask_path = "/media/ubuntu/data1/ALISURE/DUTS/DUTS-TE/DUTS-TE-Mask"
    result_path = "/media/ubuntu/data1/ALISURE/PyTorchGCN_Result/PYG_GCNAtt_NoAddGCN_NoAttRes_Sigmoid/DUTS-TE/SOD"

    # Pair each ground-truth mask with its predicted map and write the list
    # for the salmetric evaluation tool.
    result_files = get_file(mask_path, result_path)
    Tools.write_to_txt("salmetric.txt", "\n".join(result_files), reset=True)