예제 #1
0
    def __init__(self, root, list_file, train, transform, transform_att,
                 input_size):
        '''
        Build the dataset index from a CSV list file.

        Args:
          root: (str) directory to images.
          list_file: (str) path to a CSV index file; the first column of
            each row is stored in self.fnames and the second, parsed as an
            int, in self.ids. Blank lines are skipped.
          train: (boolean) train or test.
          transform: ([transforms]) image transforms.
          transform_att: ([transforms]) second set of transforms, stored as
            self.transform_att (presumably for the attention input — confirm).
          input_size: (int) model input size.
        '''
        self.root = root
        self.train = train
        self.transform = transform
        self.transform_att = transform_att
        self.input_size = input_size

        self.fnames = []
        self.boxes = []
        self.ids = []
        self.get_att = seg_attention(input_size)

        self.encoder = DataEncoder()

        # Context manager closes the handle once parsed; newline='' is what
        # the csv module requires for files it reads.
        with open(list_file, 'r', newline='') as index_file:
            for row in csv.reader(index_file):
                if not row:
                    # csv.reader yields [] for blank lines; skip instead of
                    # failing with IndexError.
                    continue
                self.fnames.append(row[0])
                self.ids.append(int(row[1]))

        self.thresh = torch.nn.Hardtanh(min_val=0, max_val=1)
예제 #2
0
    def __init__(self, root, train, transform, transform_att, input_size):
        '''
        Build the dataset index from the entries of a directory.

        Args:
          root: (str) directory to images; every entry it contains becomes
            a training sample.
          train: (boolean) train or test.
          transform: ([transforms]) image transforms.
          transform_att: ([transforms]) second set of transforms, stored as
            self.transform_att (presumably for the attention input — confirm).
          input_size: (int) model input size.
        '''
        self.root = root
        self.train = train
        self.transform = transform
        self.transform_att = transform_att
        self.input_size = input_size
        self.get_att = seg_attention(input_size)
        self.thresh = torch.nn.Hardtanh(min_val=0, max_val=1)

        self.boxes = []
        # NOTE(review): 2874 is hard-coded and independent of root's
        # contents; presumably the dataset's identity count — confirm.
        self.ids_list = list(range(2874))
        self.im_name_list = []

        self.encoder = DataEncoder()

        # os.listdir returns entries in arbitrary order; sort so the
        # position-based ids assigned below are reproducible across runs
        # and machines.
        file_list = sorted(os.listdir(root))
        self.im_name_train = file_list
        # Each entry's id is its position in the sorted listing.
        self.ids_train = list(range(len(file_list)))
예제 #3
0
    def __init__(self, root, list_file, train, transform, transform_att, input_size):
        '''
        Build the training index from a CSV list file.

        Args:
          root: (str) directory to images.
          list_file: (str) path to a CSV index file; the first column of
            each row becomes an entry of self.im_name_train and the second,
            parsed as an int, the matching entry of self.ids_train.
            Blank lines are skipped.
          train: (boolean) train or test.
          transform: ([transforms]) image transforms.
          transform_att: ([transforms]) second set of transforms, stored as
            self.transform_att (presumably for the attention input — confirm).
          input_size: (int) model input size.
        '''
        self.root = root
        self.train = train
        self.transform = transform
        self.transform_att = transform_att
        self.input_size = input_size
        self.get_att = seg_attention(input_size)
        self.thresh = torch.nn.Hardtanh(min_val=0, max_val=1)

        self.boxes = []
        # NOTE(review): not used in this constructor; kept as attributes in
        # case other methods read them. 2874 is presumably the dataset's
        # identity count — confirm.
        self.ids_list = list(range(2874))
        self.im_name_list = []

        self.encoder = DataEncoder()

        fnames = []
        ids = []
        # Context manager closes the handle; newline='' is required by csv.
        with open(list_file, 'r', newline='') as index_file:
            for row in csv.reader(index_file):
                if not row:
                    # Blank line: csv.reader yields [] — skip it.
                    continue
                fnames.append(row[0])
                ids.append(int(row[1]))

        # Every row of list_file is used for training; no validation split
        # is carved out here.
        self.im_name_train = fnames
        self.ids_train = ids
예제 #4
0
    def __init__(self,
                 root,
                 list_file,
                 input_size,
                 eval_num,
                 addition_root=None,
                 addition_list=None,
                 transform=transform,
                 transform_att=transform_pure,
                 att_flag=False,
                 flip_flag=True,
                 random_crop_flag=False,
                 center_crop_flag=True,
                 align_flag=True,
                 state='train',
                 num_ids=2874):
        '''
        Build train/validation splits from a CSV list file, optionally
        extended with an additional image set.

        Args:
          root: (str) directory the file names in list_file are joined onto.
          list_file: (str) CSV index file; each row is read as
            (file name, integer id) from its first two columns.
          input_size: (int) model input size.
          eval_num: (int) number of leftover rows (after the first row of
            every id in range(num_ids) is reserved for training) that form
            the validation split.
          addition_root: (str or None) root of an extra image set appended
            to the selected split; ignored when None.
          addition_list: (str) CSV index for addition_root, read like
            list_file; only used when addition_root is given.
          transform: image transforms.
          transform_att: second set of transforms, stored as
            self.transform_att (presumably for the attention input — confirm).
          att_flag, flip_flag, random_crop_flag, center_crop_flag: (bool)
            options stored on the instance for use by other methods.
          align_flag: (bool) if True, load a dlib frontal face detector and
            5-point landmark predictor.
          state: (str) 'train' selects the training split; any other value
            selects the validation split.
          num_ids: (int) ids 0..num_ids-1 each contribute their first row
            to the training split.

        Raises:
          ValueError: if some id in range(num_ids) has no row in list_file.
        '''
        def read_rows(path):
            # Read a CSV index, closing the file and skipping blank lines.
            with open(path, 'r', newline='') as csv_file:
                return [row for row in csv.reader(csv_file) if row]

        self.root = root
        self.list_file = list_file
        self.input_size = input_size
        self.eval_num = eval_num
        self.att_flag = att_flag
        self.align_flag = align_flag
        self.transform = transform
        self.transform_att = transform_att
        self.random_crop_flag = random_crop_flag
        self.center_crop_flag = center_crop_flag
        self.flip_flag = flip_flag

        self.encoder = DataEncoder()

        rows = read_rows(list_file)
        fnames = [os.path.join(self.root, row[0]) for row in rows]
        ids = [int(row[1]) for row in rows]

        # Reserve the first row of every id in range(num_ids) for training so
        # each of those ids appears there at least once; the remaining rows,
        # in file order, are split between validation and extra training.
        first_pos = {}
        for pos, person_id in enumerate(ids):
            first_pos.setdefault(person_id, pos)
        missing = [i for i in range(num_ids) if i not in first_pos]
        if missing:
            raise ValueError(
                '{} has no rows for ids {} (every id in range({}) is '
                'required)'.format(list_file, missing[:10], num_ids))
        reserved = [first_pos[i] for i in range(num_ids)]
        reserved_set = set(reserved)
        im_name_list = [fnames[pos] for pos in reserved]
        ids_list = list(range(num_ids))
        rest_fnames = [name for pos, name in enumerate(fnames)
                       if pos not in reserved_set]
        rest_ids = [i for pos, i in enumerate(ids) if pos not in reserved_set]

        self.im_name_valid = rest_fnames[:self.eval_num]
        self.im_name_train = rest_fnames[self.eval_num:] + im_name_list
        self.ids_valid = rest_ids[:self.eval_num]
        self.ids_train = rest_ids[self.eval_num:] + ids_list

        if self.align_flag:
            self.align_detect = dlib.get_frontal_face_detector()
            predictor_path = "./model/shape_predictor_5_face_landmarks.dat"
            self.sp = dlib.shape_predictor(predictor_path)

        if addition_root is not None:
            self.addition_root = addition_root
            self.addition_list = addition_list

            # Strips a '_dddd.jpg' suffix, e.g. 'Name_0001.jpg' -> 'Name',
            # to get the sub-folder; the dot is escaped so it matches only '.'.
            jpg_suffix = re.compile(r'_\d{4}\.jpg')
            self.add_id = []
            self.add_img = []
            for row in read_rows(self.addition_list):
                self.add_id.append(int(row[1]))
                folder_dir = jpg_suffix.sub('', row[0])
                self.add_img.append(
                    os.path.join(self.addition_root, folder_dir, row[0]))

        if state == 'train':
            split_names, split_ids = self.im_name_train, self.ids_train
        else:
            split_names, split_ids = self.im_name_valid, self.ids_valid
        # Copy so that appending the addition samples below does not also
        # extend im_name_train/ids_train (or the validation lists) in place.
        self.excuse_list = list(split_names)
        self.excuse_ids = list(split_ids)

        if addition_root is not None:
            self.excuse_list += self.add_img
            self.excuse_ids += self.add_id

        self.thresh = torch.nn.Hardtanh(min_val=0, max_val=1)
        self.get_att = seg_attention(input_size)