def get_random_samples(dataset_root, num_samples):
    transform = transforms.Compose([
        ToPILImage(),
        Resize((64, 64)),
        ToTensor(),
        FlattenTransform()
    ])

    dataset = CelebaDataset(dataset_root, TRAIN, transform)
    loader = DataLoader(dataset, batch_size=num_samples, shuffle=True)

    return next(iter(loader))
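
A minimal usage sketch (the dataset root below is hypothetical, and TRAIN is the split constant the snippet assumes to be in scope):

# Hypothetical path; the exact structure of `batch` depends on what
# CelebaDataset yields (a bare tensor, or an (image, attributes) tuple).
batch = get_random_samples("/data/celeba", num_samples=16)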
Example 2
def main():
    train_val_transform = transforms.Compose(
        [
            # transforms.Pad((4, 4, 4, 4)),
            # transforms.RandomCrop((32, 32)),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # values are between [0, 1], we want [-1, 1]
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    rf = RobotFashion(
        os.getcwd(),
        "train",
        download_if_missing=True,
        transform=train_val_transform,
        subset_ratio=1,
    )

    print("samples in data:", len(rf))

    def collate(inputs):
        images = list()
        labels = list()

        for image, label in inputs:
            images.append(image)
            labels.append(label)

        return images, labels

    data_loader = DataLoader(rf, num_workers=1, batch_size=1, collate_fn=collate)

    # get_data_loader_dist(data_loader)

    # Fetch a single batch as a smoke test of the data pipeline
    x = next(iter(data_loader))
    print(x)

    hparams = get_parser()
    model: FasterRCNNWithRobotFashion = FasterRCNNWithRobotFashion(hparams)

    y = model(x)
    print(y)
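
The list-returning `collate` above is the usual workaround for detection data, where images and annotations vary in size and cannot be stacked into a single tensor. An equivalent, more compact sketch under the same assumption of (image, label) tuples:

def collate_tuples(inputs):
    # zip(*...) transposes [(img, lbl), ...] into (imgs...), (lbls...)
    images, labels = zip(*inputs)
    return list(images), list(labels)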
Example 3
class FCNRunner(object):
    def __init__(self, args):
        self.basedir = args.basedir
        self.expname = args.expname

        self.train_set = BlenderDataset(args.datadir, 'train', args.testskip)
        self.val_set = BlenderDataset(args.datadir, 'val', args.testskip)
        self.train_loader = DataLoader(self.train_set,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers)
        self.val_loader = DataLoader(self.val_set,
                                     batch_size=1,
                                     shuffle=True,
                                     num_workers=args.num_workers)

        embedder_pos, in_ch = get_embedder(multires=args.multires)
        embedder_dir, in_ch_dir = get_embedder(multires=args.multires_views)
        in_channels = in_ch + in_ch_dir
        self.model = FCN(layers=8, in_channels=in_channels)
        self.embedder_pos = embedder_pos
        self.embedder_dir = embedder_dir
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=args.lr,
                                          betas=(0.9, 0.999))

        self.checkpoint = args.checkpoint
        self.num_epoch = args.num_epoch
        self.val_epoch = args.val_epoch
        self.i_print = args.i_print


    def load_checkpoint(self, path):
        ckpt = torch.load(path)
        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['start_epoch']
        return start_epoch

    def save_checkpoint(self, path, epoch):
        torch.save({
            'start_epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }, path)

    def train(self):
        global_step = 0
        start_epoch = 0
        if self.checkpoint is not None:
            start_epoch = self.load_checkpoint(self.checkpoint)
            global_step = start_epoch * len(self.train_set)

        log_dir = os.path.join(self.basedir, self.expname)
        os.makedirs(log_dir, exist_ok=True)
        # The newest checkpoint found in the log dir takes precedence over an
        # explicitly passed --checkpoint.
        ckpts = [os.path.join(log_dir, f) for f in sorted(os.listdir(log_dir)) if f.endswith('.pth')]
        if len(ckpts) > 0:
            print('Found checkpoint', ckpts[-1])
            start_epoch = self.load_checkpoint(ckpts[-1])
            global_step = start_epoch * len(self.train_set)

        self.model.to(device)
        start_time = time.time()
        for epoch in range(start_epoch, self.num_epoch):
            for step, data in enumerate(self.train_loader):
                time0 = time.time()
                gt_img = data['img'].to(device)
                rays_o = data['rays_o'].to(device)
                rays_d = data['rays_d'].to(device)
                embedding_pos = self.embedder_pos(rays_o).permute((0, 3, 1, 2))
                embedding_dir = self.embedder_dir(rays_d).permute((0, 3, 1, 2))

                img = self.model(embedding_pos, embedding_dir)
                img_loss = mse(img, gt_img.permute((0, 3, 1, 2)))

                loss = img_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                elapsed = time.time() - time0
                if global_step % self.i_print == 0:
                    print(f'[Train {global_step}] loss: {loss.item()} time: {elapsed:.2f} sec')
                global_step += 1

            if epoch % self.val_epoch == 0:
                with torch.no_grad():
                    data = next(iter(self.val_loader))
                    gt_img = data['img'].to(device)
                    rays_o = data['rays_o'].to(device)
                    rays_d = data['rays_d'].to(device)
                    embedding_pos = self.embedder_pos(rays_o).permute((0, 3, 1, 2))
                    embedding_dir = self.embedder_dir(rays_d).permute((0, 3, 1, 2))

                    img = self.model(embedding_pos, embedding_dir)

                    img_loss = mse(img, gt_img.permute((0, 3, 1, 2)))
                    print(f'[Val {global_step}] loss: {img_loss.item()}')

                    rgb_img = to8b(img[0].cpu().numpy().transpose((1, 2, 0)))
                    imageio.imwrite(os.path.join(self.basedir, self.expname, f"{global_step}.png"), rgb_img)


                self.save_checkpoint(os.path.join(self.basedir, self.expname, f"{epoch:03d}.pth"), epoch)

        total_time = (time.time() - start_time) / 60.0
        print(f'{total_time:.4f} min')
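
A minimal driver sketch, assuming argparse-style arguments; every name mirrors an attribute read in __init__ and train, while the defaults are hypothetical:

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--basedir', type=str, default='./logs')
    parser.add_argument('--expname', type=str, default='fcn')
    parser.add_argument('--datadir', type=str, required=True)
    parser.add_argument('--testskip', type=int, default=8)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--multires', type=int, default=10)
    parser.add_argument('--multires_views', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--num_epoch', type=int, default=100)
    parser.add_argument('--val_epoch', type=int, default=5)
    parser.add_argument('--i_print', type=int, default=100)
    FCNRunner(parser.parse_args()).train()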
Example 4
class LibDataProcessor(RawDataProcessor):
    def __init__(self, batch_size, data_workers, dataset_args):
        self.batch_size = batch_size
        self.data_workers = data_workers
        self.dataset = ReaderDataset(
            uncased_question=dataset_args.dataProcessor.uncased_question,
            uncased_doc=dataset_args.dataProcessor.uncased_doc,
            use_qemb=dataset_args.params.use_qemb,
            use_in_question=dataset_args.params.use_in_question,
            use_pos=dataset_args.params.use_pos,
            use_ner=dataset_args.params.use_ner,
            use_lemma=dataset_args.params.use_lemma,
            use_tf=dataset_args.params.use_tf)
        self.loader = None
        self.word_dict = None
        self.feature_dict = None

    def set_loader(self):
        self.loader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.data_workers,
            collate_fn=self.batchify,
        )

    def load_data(self, filename):
        self.dataset.load_data(filename)

    @overrides
    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        return iter(self.loader)

    def lengths(self):
        return [(len(ex['document']), len(ex['question']))
                for ex in self.dataset]

    def set_utils(self, word_dict, feature_dict):
        self.dataset.set_utils(word_dict, feature_dict)
        self.set_loader()

    @staticmethod
    def batchify(batch):
        """
        Gather a batch of individual examples into one batch

        :param batch: batch of examples
        :return: docs_data, x1_f, docs_mask, questions_data, questions_mask, start, end, ids
        """
        """Gather a batch of individual examples into one batch."""
        NUM_INPUTS = 3
        NUM_TARGETS = 2
        NUM_EXTRA = 1

        ids = [ex[-1] for ex in batch]
        docs = [ex[0] for ex in batch]
        features = [ex[1] for ex in batch]
        questions = [ex[2] for ex in batch]

        # Batch documents and features (mask convention: 1 = padding, 0 = token)
        max_length = max([d.size(0) for d in docs])
        docs_indices = torch.LongTensor(len(docs), max_length).zero_()
        docs_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
        if features[0] is None:
            docs_feature = None
        else:
            docs_feature = torch.zeros(len(docs), max_length,
                                       features[0].size(1))
        for i, d in enumerate(docs):
            docs_indices[i, :d.size(0)].copy_(d)
            docs_mask[i, :d.size(0)].fill_(0)
            if docs_feature is not None:
                docs_feature[i, :d.size(0)].copy_(features[i])

        # Batch questions
        max_length = max([q.size(0) for q in questions])
        questions_indices = torch.LongTensor(len(questions),
                                             max_length).zero_()
        questions_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
        for i, q in enumerate(questions):
            questions_indices[i, :q.size(0)].copy_(q)
            questions_mask[i, :q.size(0)].fill_(0)

        # Maybe return without targets
        if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
            return docs_indices, docs_feature, docs_mask, questions_indices, questions_mask, ids

        elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS:
            # ...Otherwise add targets
            if torch.is_tensor(batch[0][3]):
                start = torch.cat([ex[3] for ex in batch])
                end = torch.cat([ex[4] for ex in batch])
            else:
                start = [ex[3] for ex in batch]
                end = [ex[4] for ex in batch]
        else:
            raise RuntimeError('Incorrect number of inputs per example.')

        return docs_indices, docs_feature, docs_mask, questions_indices, questions_mask, start, end, ids

    @classmethod
    def from_params(cls, params: Params) -> 'LibDataProcessor':
        return cls(params.dataProcessor.batch_size,
                   params.dataProcessor.data_workers, params)
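
A small self-contained check of batchify's padding and mask convention, using dummy (doc, features, question, id) examples with no targets; all sizes are arbitrary:

import torch

batch = [
    (torch.arange(5), None, torch.arange(3), 'ex0'),
    (torch.arange(2), None, torch.arange(4), 'ex1'),
]
docs, feats, docs_mask, qs, qs_mask, ids = LibDataProcessor.batchify(batch)
assert docs.shape == (2, 5)          # documents padded to the longest one
assert bool(docs_mask[1, 2:].all())  # mask is 1 over padding positions
assert feats is None                 # no token features in this dummy batch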