示例#1
0
    def evaluate(self, x: GLOBAL.np.ndarray, y: GLOBAL.np.ndarray, batch_size: int = None):
        """Compute ``(acc, loss)`` of the model on ``(x, y)`` in eval mode.

        Args:
            x: Input samples; converted with ``np.asarray``.
            y: Targets aligned with ``x``; converted with ``np.asarray``.
            batch_size: If given, evaluate in mini-batches of this size and
                combine the per-batch metrics; otherwise run a single forward
                pass over the whole input.

        Returns:
            ``(acc, loss)`` as produced by ``self.loss.metric``. In batched
            mode each value is the batch-size-weighted mean of the per-batch
            values, so it matches the unbatched result when the metric is a
            per-sample mean (assumption about ``self.loss.metric`` — confirm).
            An empty input yields ``(nan, nan)`` in batched mode.
        """
        self.eval()
        x = GLOBAL.np.asarray(x)
        y = GLOBAL.np.asarray(y)
        if batch_size is None:
            y_pred = self.forward(F.Tensor(x))
            acc, loss = self.loss.metric(y_pred, F.Tensor(y))
            return acc, loss

        assert type(batch_size) is int
        val_dataset = data.DataSet(x, y)
        val_dataloader = data.DataLoader(val_dataset, batch_size)
        last_idx = len(val_dataloader) - 1
        acc_list = []
        loss_list = []
        sizes = []
        for idx, (batch_x, batch_y) in enumerate(val_dataloader):
            # The first and the (possibly smaller) last batch can differ in
            # shape, so the graph output tensors are (re)initialised for both.
            if idx == 0 or idx == last_idx:
                self.init_graph_out_tensor(batch_x, batch_y)

            y_pred = self.forward(batch_x)
            metric = self.loss.metric(y_pred, batch_y)
            acc_list.append(metric[0])
            loss_list.append(metric[1])
            sizes.append(batch_x.shape[0])

        total = sum(sizes)
        if total == 0:
            # Nothing was evaluated; keep the previous "mean of nothing" (nan).
            return float('nan'), float('nan')
        # Weight by batch size so a short final batch does not skew the mean.
        acc = GLOBAL.np.average(acc_list, weights=sizes).tolist()
        loss = GLOBAL.np.average(loss_list, weights=sizes).tolist()
        return acc, loss
示例#2
0
def select_batch(dataset: data.Dataset,
                 size: int,
                 generator: torch.Generator = None):
    """Return one randomly drawn batch of up to ``size`` items from ``dataset``.

    A shuffling ``DataLoader`` is built over ``dataset`` and only its first
    batch is taken. Supply a seeded ``generator`` for reproducible draws.
    """
    shuffled_loader = data.DataLoader(
        dataset, batch_size=size, shuffle=True, generator=generator,
    )
    batch_iter = iter(shuffled_loader)
    return next(batch_iter)
示例#3
0
    def predict(self, x: GLOBAL.np.ndarray, batch_size: int = None):
        """Run the model forward on ``x`` in eval mode and return the output.

        Args:
            x: Input samples.
            batch_size: If given, ``x`` is wrapped in ``data.DataSet`` and fed
                through ``data.DataLoader`` in chunks of this size, and the
                per-chunk outputs are joined with ``F.concat`` along axis 0.
                Otherwise one forward pass is made on ``F.Tensor(x)``.

        Returns:
            The model output for all of ``x``.
        """
        self.eval()
        if batch_size is None:
            return self.forward(F.Tensor(x))

        assert type(batch_size) is int
        chunk_loader = data.DataLoader(data.DataSet(x), batch_size)
        chunk_outputs = [self.forward(chunk) for chunk in chunk_loader]
        return F.concat(chunk_outputs, axis=0)
示例#4
0
    def build_dataloader(self, datasets: T.Dict[str, data.Dataset], n_support: int = None, n_queries: int = None, generator: torch.Generator = None):
        """Build a few-shot episode loader over several per-class datasets.

        Each dataset in ``datasets`` (in iteration order) is one class whose
        position is used as its integer label. Datasets shorter than the
        longest one are repeated (``dataset + dataset + ...``) until they are
        at least as long, then all are concatenated into a shuffled query
        pool served in batches of ``n_queries``. One support batch of
        ``n_support`` items per class is drawn once and appended unchanged to
        every batch.

        Args:
            datasets: Mapping of class name to dataset; also stored on
                ``self.last_datasets``.
            n_support: Support batch size per class; falls back to
                ``self.n_support`` when falsy.
            n_queries: Query batch size; falls back to ``self.n_queries`` when
                falsy.
            generator: RNG used for support sampling and query shuffling;
                falls back to ``self.generator`` when falsy.

        Returns:
            A ``data.DataLoader`` whose batches are
            ``(queries, labels, *supports)``.

        Raises:
            ValueError: if ``datasets`` is empty, or if a dataset is empty
                while another is not (it could never be padded).
        """
        n_support = n_support or self.n_support
        n_queries = n_queries or self.n_queries
        generator = generator or self.generator

        self.last_datasets = datasets
        dslist = list(datasets.values())
        if not dslist:
            raise ValueError("build_dataloader() needs at least one dataset")

        # Oversample smaller classes so every class contributes roughly as
        # many queries as the largest one.
        target_len = max(len(ds) for ds in dslist)
        for i, dataset in enumerate(dslist):
            if len(dataset) < target_len:
                if len(dataset) == 0:
                    # Repeating an empty dataset never grows it; without this
                    # check the padding loop below would never terminate.
                    raise ValueError(
                        f"dataset {i} is empty but others have {target_len} items; cannot pad it"
                    )
                new_dataset = dataset + dataset
                while len(new_dataset) < target_len:
                    new_dataset += dataset
                dslist[i] = new_dataset

        queries = dslist[0]
        labels = [0] * len(dslist[0])

        for i in range(1, len(dslist)):
            queries += dslist[i]
            labels += [i] * len(dslist[i])

        # Select supports
        # Will be constant for all queries
        supports = []
        for dataset in dslist:
            support = self.select_batch(dataset, size=n_support, generator=generator)
            supports.append(support)

        def collate_fn(batch):
            """Stack query rows, tensorize labels, and append the fixed supports."""
            queries = []
            labels = []
            for row in batch:
                query, label = row
                queries.append(query)
                labels.append(label)
            queries = torch.stack(queries, dim=0)
            labels = torch.tensor(labels, device=queries.device)
            return queries, labels, *supports

        dataset = data.dzip(queries, labels)
        # Use the resolved n_queries (argument or self.n_queries fallback);
        # previously the argument was computed above but ignored here.
        return data.DataLoader(dataset, batch_size=n_queries, collate_fn=collate_fn, shuffle=True, generator=generator)
示例#5
0
def build_dataloader(datasets: T.Dict[str, data.Dataset],
                     n_support: int,
                     n_queries: int,
                     generator: torch.Generator = None):
    """Build a few-shot episode loader over several per-class datasets.

    Each dataset in ``datasets`` (in iteration order) is one class; its
    position is the integer label of its samples. All samples are joined
    into one query pool (paired with their labels via ``data.dzip``),
    shuffled, and served in batches of ``n_queries``. One support batch of
    ``n_support`` items per class is drawn once, up front, and appended
    unchanged to every batch.

    Returns:
        A ``data.DataLoader`` whose batches are ``(queries, labels,
        *supports)``: the stacked query rows, their labels as a tensor on the
        same device, then the per-class support batches.
    """
    class_sets = list(datasets.values())

    # Label every sample by the index of the class dataset it came from.
    labels = [cls_idx
              for cls_idx, class_set in enumerate(class_sets)
              for _ in range(len(class_set))]

    # Concatenate all class datasets into a single query pool.
    queries = class_sets[0]
    for extra in class_sets[1:]:
        queries += extra

    # Supports are sampled once per class and shared by every batch.
    supports = [select_batch(class_set, size=n_support, generator=generator)
                for class_set in class_sets]

    def collate_fn(batch):
        """Stack query rows, tensorize labels, and append the fixed supports."""
        stacked = torch.stack([query for query, _ in batch], dim=0)
        targets = torch.tensor([label for _, label in batch], device=stacked.device)
        return stacked, targets, *supports

    query_pool = data.dzip(queries, labels)
    return data.DataLoader(query_pool,
                           batch_size=n_queries,
                           collate_fn=collate_fn,
                           shuffle=True,
                           generator=generator)
示例#6
0
def fewshot_dataloader(datasets: T.List[data.Dataset], n_support: int, n_queries: int, generator: torch.Generator = None, select_batch_fn=select_batch):
    """Build a few-shot episode loader over several per-class datasets.

    Each element of ``datasets`` is one class; its position is its integer
    label. Datasets shorter than the longest one are repeated
    (``dataset + dataset + ...``) until they are at least as long, then all
    are concatenated into a shuffled query pool served in batches of
    ``n_queries``. One support batch of ``n_support`` items per class is
    drawn once with ``select_batch_fn`` and appended to every batch.

    The caller's ``datasets`` sequence is not modified; padding happens on a
    private copy.

    Returns:
        A ``data.DataLoader`` whose batches are ``(queries, labels,
        *supports)``.

    Raises:
        ValueError: if ``datasets`` is empty, or if a dataset is empty while
            another is not (it could never be padded).
    """
    # Copy so padded datasets do not overwrite entries in the caller's list.
    datasets = list(datasets)
    if not datasets:
        raise ValueError("fewshot_dataloader() needs at least one dataset")

    target_len = max(len(ds) for ds in datasets)
    for i, dataset in enumerate(datasets):
        if len(dataset) < target_len:
            if len(dataset) == 0:
                # Repeating an empty dataset never grows it; without this
                # check the padding loop below would never terminate.
                raise ValueError(
                    f"dataset {i} is empty but others have {target_len} items; cannot pad it"
                )
            new_dataset = dataset + dataset
            while len(new_dataset) < target_len:
                new_dataset += dataset
            datasets[i] = new_dataset

    queries = datasets[0]
    labels = [0] * len(datasets[0])

    for i in range(1, len(datasets)):
        queries += datasets[i]
        labels += [i] * len(datasets[i])

    # Select supports
    # Will be constant for all queries
    supports = []
    for dataset in datasets:
        support = select_batch_fn(dataset, size=n_support, generator=generator)
        supports.append(support)

    def collate_fn(batch):
        """Stack query rows, tensorize labels, and append the fixed supports."""
        queries = []
        labels = []
        for row in batch:
            query, label = row
            queries.append(query)
            labels.append(label)
        queries = torch.stack(queries, dim=0)
        labels = torch.tensor(labels, device=queries.device)
        return queries, labels, *supports

    dataset = data.dzip(queries, labels)
    return data.DataLoader(dataset, batch_size=n_queries, collate_fn=collate_fn, shuffle=True, generator=generator)
示例#7
0
    def fit(self,
            x: GLOBAL.np.ndarray,
            y: GLOBAL.np.ndarray,
            batch_size: int = None,
            epochs: int = 1,
            verbose: int = 1,
            shuffle: bool = True,
            validation_data: Tuple[GLOBAL.np.ndarray] = None,
            validation_split: float = 0.,
            initial_epoch: int = 0,
            ) -> SummaryProfile:
        """Train the model on ``(x, y)`` and return the recorded metrics.

        Args:
            x: Training inputs; converted with ``np.asarray``.
            y: Training targets aligned with ``x``; converted with ``np.asarray``.
            batch_size: Mini-batch size for ``data.DataLoader``; also used for
                the validation pass.
            epochs: Exclusive upper bound of the epoch index range.
            verbose: ``0`` suppresses epoch headers; forwarded to ``ProgressBar``.
            shuffle: Forwarded to ``data.DataLoader`` for the training set.
            validation_data: Optional ``(valid_x, valid_y)`` scored with
                ``self.evaluate`` during training; takes precedence over
                ``validation_split``.
            validation_split: Fraction in (0, 1) of the *last* samples of
                ``x``/``y`` held out for validation when ``validation_data``
                is not given. Ignored if it rounds down to zero samples.
            initial_epoch: First epoch index to run.

        Returns:
            The ``SummaryProfile`` holding ``train_loss``/``train_acc`` steps
            and, when validating, ``validation_loss``/``validation_acc`` steps.
        """
        x = GLOBAL.np.asarray(x)
        y = GLOBAL.np.asarray(y)
        train_x, train_y = x, y
        if validation_data is None and 0. < validation_split < 1.:
            split = int(x.shape[0] * validation_split)
            # With split == 0, x[-0:] is the whole set and x[:-0] is empty,
            # i.e. validate on everything and train on nothing; skip instead.
            if split > 0:
                validation_data = (x[-split:], y[-split:])
                train_x, train_y = x[:-split], y[:-split]

        record_data_names = ['train_loss', 'train_acc']
        # Register the validation series whenever validation actually runs,
        # including when the caller supplies validation_data directly.
        if validation_data is not None:
            record_data_names.append('validation_loss')
            record_data_names.append('validation_acc')

        train_dataset = data.DataSet(train_x, train_y)
        train_dataloader = data.DataLoader(train_dataset, batch_size, shuffle)
        v_acc, v_loss = 0., 0.
        with SummaryProfile(*record_data_names) as profile:
            for epoch in range(initial_epoch, epochs):
                epoch_start_time = time.time()
                if verbose != 0:
                    print('\033[1;31m Epoch[%d/%d]\033[0m' % (epoch + 1, epochs))
                progress_bar = ProgressBar(max_iter=len(train_x), verbose=verbose)
                for idx, (batch_x, batch_y) in enumerate(train_dataloader):
                    if GLOBAL.USE_CUDA:
                        batch_x.cuda_()
                        batch_y.cuda_()
                    else:
                        batch_x.cpu_()
                        batch_y.cpu_()
                    # First and (possibly smaller) last batch can differ in shape.
                    if idx == 0 or idx == len(train_dataloader) - 1:
                        self.init_graph_out_tensor(batch_x, batch_y)
                    # evaluate() switches to eval mode, so re-enter train mode each step.
                    self.train()
                    # reset trainable_variables grad
                    self.optimizer.zero_grad()
                    # forward
                    pred = self.forward(batch_x)
                    loss = self.loss.forward(pred, batch_y)
                    self.backward(loss)
                    self.optimizer.step()
                    epoch_time = time.time() - epoch_start_time
                    train_acc = self.loss.calc_acc(pred, batch_y)
                    profile.step('train_acc', train_acc)
                    profile.step('train_loss', loss.item())

                    # NOTE(review): this scores the full validation set after
                    # every training batch; confirm per-batch (not per-epoch)
                    # validation is intended.
                    if validation_data is not None:
                        valid_x, valid_y = validation_data
                        v_acc, v_loss = self.evaluate(valid_x, valid_y, batch_size=batch_size)
                        profile.step('validation_loss', v_loss)
                        profile.step('validation_acc', v_acc)

                    progress_bar.update(batch_x.shape[0])
                    progress_bar.console(verbose, epoch_time, loss.item(), train_acc, v_loss,
                                         v_acc)

                if verbose != 0:
                    print()

        return profile