Пример #1
0
 def __init__(self, dataset_class, txt_file, root_dir, transform, num_workers, argus=None):
     """Build the Stage-2 trainer with four optimizer/loss/metric/scheduler slots.

     Args:
         dataset_class, txt_file, root_dir, transform, num_workers: forwarded
             unchanged to ``self.load_dataset`` at the end of construction.
         argus: options object read for ``cuda``, ``optimizer``, ``lr``,
             ``display_freq`` and, for the non-Adam branch, ``momentum`` and
             ``weight_decay``. When omitted, the module-level ``args`` is used,
             exactly as before this parameter existed.
     """
     super(TrainClass, self).__init__()
     # Fall back to the module-level ``args`` so existing callers keep working,
     # while allowing an explicit options object to be injected.
     self.args = args if argus is None else argus
     # Number of parallel optimizer / criterion / metric / scheduler slots.
     num_slots = 4
     self.device = torch.device(
         "cuda:%d" % self.args.cuda if torch.cuda.is_available() else "cpu")
     self.writer = SummaryWriter('log_three')
     self.model = Stage2Model().to(self.device)
     # NOTE(review): every optimizer below is built over the *same* parameter
     # set (self.model.parameters()), so stepping any one of them updates the
     # whole model — confirm this is intended.
     if self.args.optimizer == 'Adam':
         self.optimizer = [optim.Adam(self.model.parameters(), self.args.lr)
                           for _ in range(num_slots)]
     else:
         self.optimizer = [optim.SGD(self.model.parameters(), self.args.lr,
                                     momentum=self.args.momentum,
                                     weight_decay=self.args.weight_decay)
                           for _ in range(num_slots)]
     self.criterion = [nn.CrossEntropyLoss() for _ in range(num_slots)]
     self.metric = [F1Accuracy() for _ in range(num_slots)]
     # Presumably populated by load_dataset() below — TODO confirm.
     self.train_loader = None
     self.eval_loader = None
     # ``uuid`` is a module-level value defined outside this method
     # (presumably a per-run identifier string — verify).
     self.ckpt_dir = "checkpoint_%s" % uuid
     self.display_freq = self.args.display_freq
     # One StepLR per optimizer: multiply the LR by 0.5 every 5 scheduler steps.
     self.scheduler = [optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)
                       for opt in self.optimizer]
     self.best_error = [float('Inf')] * num_slots
     self.best_accu = [float('-Inf')] * num_slots
     self.load_dataset(dataset_class, txt_file, root_dir, transform, num_workers)
Пример #2
0
    def __init__(self, batch_size, is_shuffle, num_workers, state_files):
        """Set up the stage-1 / stage-2 / reverse models plus F1 bookkeeping.

        Args:
            batch_size, is_shuffle, num_workers, state_files: stored on the
                instance; presumably consumed by ``get_dataloader()``, which is
                called last — TODO confirm.
        """
        super(Simple_two, self).__init__()
        self.args = None
        # get_args() is expected to populate self.args, which is read just below.
        self.get_args()
        self.device = torch.device("cuda:%d" % self.args.cuda if torch.cuda.is_available() else "cpu")
        self.model1 = Stage1Model().to(self.device)
        self.model2 = Stage2Model().to(self.device)
        self.reverse = ReverseTModel().to(self.device)
        self.root_dir = "/data1/yinzi/datas"
        self.predict1 = self.predict2 = self.all_predict = None
        self.best_error = float('Inf')

        names = ['eyebrow1', 'eyebrow2',
                 'eye1', 'eye2',
                 'nose', 'u_lip', 'i_mouth', 'l_lip']
        self.F1_name_list = names
        # Per-part counters start at a tiny epsilon instead of 0.0
        # (presumably to keep later ratios finite — verify).
        eps = 1e-20
        self.TP = dict.fromkeys(names, eps)
        self.FP = dict.fromkeys(names, eps)
        self.TN = dict.fromkeys(names, eps)
        self.FN = dict.fromkeys(names, eps)
        self.recall = dict.fromkeys(names, eps)
        self.precision = dict.fromkeys(names, eps)
        # Lists must be distinct per key, so build them with a comprehension.
        self.F1_list = {name: [] for name in names}
        self.F1 = dict.fromkeys(names, eps)

        # Overall-score bookkeeping.
        self.recall_overall_list = {name: [] for name in names}
        self.precision_overall_list = {name: [] for name in names}
        self.recall_overall = self.precision_overall = self.F1_overall = 0.0
        self.dataset = self.dataloader = None
        self.batch_size, self.is_shuffle = batch_size, is_shuffle
        self.num_workers, self.state_files = num_workers, state_files
        self.map_location = self.device
        self.stage2_loss_func = torch.nn.CrossEntropyLoss()
        self.optim2 = optim.Adam(self.model2.parameters(), self.args.lr)
        self.step = self.epoch = 0
        self.ckpt_dir = "checkpoints_%s" % uuid
        self.writer = SummaryWriter("logs")
        self.get_dataloader()
Пример #3
0
    def __init__(self, argus=args):
        """Set up a single-head Stage-2 trainer on the stage-1 dataloaders.

        Args:
            argus: options object read for ``cuda``, ``lr`` and
                ``display_freq``. Note that the default is the module-level
                ``args`` as it was when this ``def`` statement executed.
        """
        super(TrainModel, self).__init__()
        self.label_channels = 9
        # ============== not necessary ===============
        self.train_logger = None
        self.eval_logger = None
        self.args = argus

        # ============== necessary ===============
        self.writer = SummaryWriter('log')
        self.step = 0
        self.epoch = 0
        self.best_error = float('Inf')
        self.best_accu = float('-Inf')

        # Read options through self.args (not the global ``args``) so that an
        # explicitly passed ``argus`` is honoured everywhere.
        self.device = torch.device(
            "cuda:%d" % self.args.cuda if torch.cuda.is_available() else "cpu")

        self.model = Stage2Model().to(self.device)
        self.model1 = Stage1Model().to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), self.args.lr)
        self.criterion = nn.CrossEntropyLoss()
        self.metric = F1Accuracy()
        # Multiply the learning rate by 0.5 every 5 scheduler steps.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=5,
                                                   gamma=0.5)

        # Loaders come from the module-level ``stage1_dataloaders`` mapping.
        self.train_loader = stage1_dataloaders['train']
        self.eval_loader = stage1_dataloaders['val']

        self.ckpt_dir = "checkpoints_%s" % uuid
        self.display_freq = self.args.display_freq

        # Call it to check that all members have been initialized.
        self.check_init()
Пример #4
0
 def __init__(self, dataset_class, txt_file, root_dir, transform, num_workers, device_ids=None):
     """Multi-GPU variant: wrap the base-class model in ``nn.DataParallel``.

     Args:
         dataset_class, txt_file, root_dir, transform, num_workers: forwarded
             unchanged to the base-class constructor.
         device_ids: GPU indices to replicate the model across. Defaults to
             ``[0, 1, 2, 3, 4]``, the set that was previously hard-coded.
     """
     super(TrainMGPU, self).__init__(dataset_class, txt_file, root_dir, transform, num_workers)
     if device_ids is None:
         device_ids = [0, 1, 2, 3, 4]
     # Wrap the model the base constructor already built rather than creating
     # a fresh Stage2Model. This assumes the base constructor builds its
     # optimizers/schedulers over self.model.parameters() (TrainClass.__init__
     # with this same signature does). DataParallel reuses the wrapped module's
     # parameters, so those optimizers keep updating the model being trained.
     self.model = nn.DataParallel(self.model, device_ids=list(device_ids))
     # NOTE(review): DataParallel expects the module on device_ids[0]; this only
     # holds when self.device is that GPU — confirm args.cuda matches.
     self.model = self.model.to(self.device)