Example #1
0
    # Training-loop fragment; the enclosing function is not visible here. It
    # relies on names bound earlier: device, args, iteration, step, alpha,
    # dataloader, generator, discriminator, face_align_net, g_optim, d_optim
    # and train.
    MSE_Loss = nn.MSELoss().to(device)
    MSE_Loss_sum = nn.MSELoss(reduction='sum').to(device)
    L1_Loss = nn.L1Loss().to(device)
    # NOTE(review): the three losses above are not referenced in the lines
    # shown here (train() below is not passed them) — confirm they are needed.

    # alpha = min(1, alpha_weight * iteration) reaches 1 once iteration hits
    # step_iteration / 2, so alpha ramps linearly over the first half of a step.
    alpha_weight = float(2. / args.step_iteration)

    # Run the remaining iterations; the loop index itself is unused.
    for i in range(args.max_iter - iteration):
        # After step_iteration iterations, move to the next step (capped at 3)
        # and restart the per-step counter. Once step == 3 the counter keeps
        # growing and alpha stays at 1.
        if iteration >= args.step_iteration:
            if step < 3:
                alpha = 0  # NOTE(review): overwritten by the assignment just below
                iteration = 0
                step += 1
        alpha = min(1, alpha_weight * iteration)
        # On OSError/StopIteration the fetch is retried exactly once; a second
        # failure propagates. Presumably the loader restarts itself after
        # StopIteration (as a RecDataLoader-style iterator does) — TODO confirm.
        try:
            dat = dataloader.__next__()
            x2_target_image, x4_target_image, x8_target_image, input_image = dat
        except (OSError, StopIteration):
            dat = dataloader.__next__()
            x2_target_image, x4_target_image, x8_target_image, input_image = dat
        iteration += 1
        input_image = input_image.to(device)
        # Pick the supervision target for the current step: step 1 -> x2,
        # step 2 -> x4, step 3 -> x8. For any other step value target_image is
        # not assigned here (a stale or undefined value would be used).
        if step == 1:
            target_image = x2_target_image.to(device)
        elif step == 2:
            target_image = x4_target_image.to(device)
        elif step == 3:
            target_image = x8_target_image.to(device)

        train(generator, discriminator, face_align_net, g_optim, d_optim,
              input_image, target_image, step, iteration, alpha)
Example #2
0
class RecDataLoader:
    """Length-bucketed batch loader for text-recognition samples.

    Wraps a ``DataLoader`` that yields one sample at a time and sorts samples
    into two queues by label length: ``queue_1`` holds labels no longer than
    ``len_thresh``, ``queue_2`` holds the rest. A batch is emitted as soon as
    either queue holds ``batch_size`` samples, so each batch contains labels
    of similar length. Partial leftovers are flushed at the end of the epoch.
    After the final ``StopIteration`` the loader resets itself, so it can be
    iterated again.

    Args:
        dataset: dataset exposing ``process`` (image helpers),
            ``_find_max_length()``, ``__len__`` and ``__getitem__``. Samples
            are indexed with ``'img'`` and ``'label'`` keys. Presumably the
            label is a string and the image is H x W x C (TODO confirm).
        batch_size: maximum number of samples per emitted batch.
        shuffle: forwarded to the inner ``DataLoader``.
        num_workers: forwarded to the inner ``DataLoader``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, **kwargs):
        self.dataset = dataset
        self.process = dataset.process
        # Labels up to half the longest label length go to the "short" queue.
        self.len_thresh = self.dataset._find_max_length() // 2
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.iteration = 0      # samples pulled from the inner loader this epoch
        self.dataiter = None    # inner iterator, built lazily by build()
        self.queue_1 = list()   # short-label samples awaiting a batch
        self.queue_2 = list()   # long-label samples awaiting a batch

    def __len__(self):
        """Return ceil(len(dataset) / batch_size).

        Because the two queues each flush their own partial batch, an epoch
        may actually yield one batch more than this.
        """
        return -(-len(self.dataset) // self.batch_size)

    def __iter__(self):
        return self

    def pack(self, batch_data):
        """Collate single-sample items into one padded batch.

        Each image is passed through ``resize_with_specific_height`` and
        right-padded (``width_pad_img``) to the widest image in the batch,
        rounded up to a multiple of 8. It is then normalized and transposed
        from H x W x C to C x H x W.

        Args:
            batch_data: non-empty list of items produced by the inner
                ``DataLoader`` (batch_size=1).

        Returns:
            dict with ``'img'`` (stacked float tensor) and ``'label'`` (list
            of labels, in the same order as ``batch_data``).
        """
        # Items come from a batch_size=1 DataLoader, hence the [0] unwrapping.
        resized = [
            self.process.resize_with_specific_height(item['img'][0].numpy())
            for item in batch_data
        ]
        max_img_w = max(img.shape[1] for img in resized)
        # Round the width up to a multiple of 8 (integer ceil division).
        max_img_w = -(-max_img_w // 8) * 8
        images, labels = [], []
        for item, img in zip(batch_data, resized):
            img = self.process.normalize_img(
                self.process.width_pad_img(img, max_img_w))
            images.append(torch.FloatTensor(img.transpose([2, 0, 1])))
            labels.append(item['label'][0])
        return {'img': torch.stack(images), 'label': labels}

    def build(self):
        """(Re)create the inner one-sample-at-a-time iterator."""
        self.dataiter = iter(DataLoader(self.dataset,
                                        batch_size=1,
                                        shuffle=self.shuffle,
                                        num_workers=self.num_workers))

    def _flush(self):
        """Emit a leftover partial batch, or end the epoch if none remain.

        Short-label leftovers are emitted before long-label ones. When both
        queues are empty, the epoch state is reset and ``StopIteration`` is
        raised.
        """
        if self.queue_1:
            batch_data, self.queue_1 = self.queue_1, list()
            return self.pack(batch_data)
        if self.queue_2:
            batch_data, self.queue_2 = self.queue_2, list()
            return self.pack(batch_data)
        self.iteration = 0
        self.dataiter = None
        raise StopIteration

    def __next__(self):
        if self.dataiter is None:
            self.build()
        # Every sample has already been pulled: only leftovers can remain.
        if self.iteration == len(self.dataset):
            return self._flush()
        while True:
            try:
                temp = next(self.dataiter)
            except StopIteration:
                # Inner loader exhausted. Flush whichever queue still holds
                # samples; the previous code always packed queue_1, which
                # crashed on an empty queue_1 while queue_2 had leftovers.
                return self._flush()
            self.iteration += 1
            if len(temp['label'][0]) <= self.len_thresh:
                self.queue_1.append(temp)
            else:
                self.queue_2.append(temp)
            # Only one sample was added, so at most one queue can be full.
            if len(self.queue_1) == self.batch_size:
                batch_data, self.queue_1 = self.queue_1, list()
                return self.pack(batch_data)
            if len(self.queue_2) == self.batch_size:
                batch_data, self.queue_2 = self.queue_2, list()
                return self.pack(batch_data)
Example #3
0
class RecDataLoader:
    """Length-bucketed batch loader for text-recognition samples.

    Wraps a ``DataLoader`` that yields one sample at a time and sorts samples
    into two queues by label length (``item[2]``): ``queue_1`` holds lengths no
    greater than ``len_thresh``, ``queue_2`` holds the rest. A batch is
    emitted as soon as either queue holds ``batch_size`` samples. Partial
    leftovers are flushed at the end of the epoch. After the final
    ``StopIteration`` the loader resets itself, so it can be iterated again.

    Args:
        dataset: dataset exposing ``_find_max_length()``, ``__len__`` and
            ``__getitem__``. Samples are unpacked as ``(img, label, length)``.
        config: object providing ``batch_size``, ``shuffle`` and
            ``num_workers``. It is also passed to ``RecDataProcess``.
    """

    def __init__(self, dataset, config):
        self.dataset = dataset
        self.process = RecDataProcess(config)
        # Lengths up to half the longest label length go to the "short" queue.
        self.len_thresh = self.dataset._find_max_length() // 2
        self.batch_size = config.batch_size
        self.shuffle = config.shuffle
        self.num_workers = config.num_workers
        self.iteration = 0      # samples pulled from the inner loader this epoch
        self.dataiter = None    # inner iterator, built lazily by build()
        self.queue_1 = list()   # short-label samples awaiting a batch
        self.queue_2 = list()   # long-label samples awaiting a batch

    def __len__(self):
        """Return ceil(len(dataset) / batch_size).

        Because the two queues each flush their own partial batch, an epoch
        may actually yield one batch more than this.
        """
        return -(-len(self.dataset) // self.batch_size)

    def __iter__(self):
        return self

    def pack(self, batch_data):
        """Collate single-sample items into ``[images, labels, lengths]``.

        The images are rescaled by ``process.resize_normalize`` to the widest
        image in the batch, rounded up to a multiple of 8. Labels are
        zero-padded to the longest length in the batch.

        Args:
            batch_data: non-empty list of ``(img, label, length)`` items from
                the inner ``DataLoader`` (batch_size=1). Presumably img is
                (1, H, W) after collation — TODO confirm.

        Returns:
            list of three tensors: stacked float images, stacked int labels
            and concatenated int lengths.
        """
        max_length = max(item[2].item() for item in batch_data)
        max_img_w = max(item[0].shape[-1] for item in batch_data)
        # Round the width up to a multiple of 8 (integer ceil division).
        max_img_w = -(-max_img_w // 8) * 8
        images, labels, lengths = [], [], []
        for img_t, label_t, _ in batch_data:
            # Back to numpy with the leading axis moved last.
            img = self.process.resize_normalize(
                img_t.numpy().transpose([1, 2, 0]), max_img_w)
            label = label_t.tolist()[0]
            label = label + [0] * (max_length - len(label))
            images.append(torch.FloatTensor(img))
            labels.append(torch.IntTensor(label))
            # NOTE(review): every sample reports the batch-wide max_length,
            # not its own length (item[2]). If these feed CTC target_lengths,
            # the zero padding is counted as targets — confirm intent.
            lengths.append(torch.IntTensor([max_length]))
        return [torch.stack(images), torch.stack(labels), torch.cat(lengths)]

    def build(self):
        """(Re)create the inner one-sample-at-a-time iterator."""
        self.dataiter = iter(DataLoader(self.dataset, batch_size=1,
                                        shuffle=self.shuffle,
                                        num_workers=self.num_workers))

    def _flush(self):
        """Emit a leftover partial batch, or end the epoch if none remain.

        Short-label leftovers are emitted before long-label ones. When both
        queues are empty, the epoch state is reset and ``StopIteration`` is
        raised.
        """
        if self.queue_1:
            batch_data, self.queue_1 = self.queue_1, list()
            return self.pack(batch_data)
        if self.queue_2:
            batch_data, self.queue_2 = self.queue_2, list()
            return self.pack(batch_data)
        self.iteration = 0
        self.dataiter = None
        raise StopIteration

    def __next__(self):
        if self.dataiter is None:
            self.build()
        # Every sample has already been pulled: only leftovers can remain.
        if self.iteration == len(self.dataset):
            return self._flush()
        while True:
            try:
                temp = next(self.dataiter)
            except StopIteration:
                # Inner loader exhausted. Flush whichever queue still holds
                # samples; the previous code always packed queue_1, which
                # crashed on an empty queue_1 while queue_2 had leftovers.
                return self._flush()
            self.iteration += 1
            if temp[2].item() <= self.len_thresh:
                self.queue_1.append(temp)
            else:
                self.queue_2.append(temp)
            # Only one sample was added, so at most one queue can be full.
            if len(self.queue_1) == self.batch_size:
                batch_data, self.queue_1 = self.queue_1, list()
                return self.pack(batch_data)
            if len(self.queue_2) == self.batch_size:
                batch_data, self.queue_2 = self.queue_2, list()
                return self.pack(batch_data)