def evaluate(self, x: GLOBAL.np.ndarray, y: GLOBAL.np.ndarray, batch_size: int = None):
    self.eval()
    x = GLOBAL.np.asarray(x)
    y = GLOBAL.np.asarray(y)
    if batch_size is not None:
        assert isinstance(batch_size, int)
        val_dataset = data.DataSet(x, y)
        val_dataloader = data.DataLoader(val_dataset, batch_size)
        acc_list = []
        loss_list = []
        for idx, (batch_x, batch_y) in enumerate(val_dataloader):
            # Re-initialise the graph's output tensor for the first batch and
            # for the last one, which may have a different size.
            if idx == 0 or idx == len(val_dataloader) - 1:
                self.init_graph_out_tensor(batch_x, batch_y)
            y_pred = self.forward(batch_x)
            metric = self.loss.metric(y_pred, batch_y)
            acc_list.append(metric[0])
            loss_list.append(metric[1])
        # Average the per-batch metrics over the whole validation set.
        acc = GLOBAL.np.array(acc_list).mean().tolist()
        loss = GLOBAL.np.array(loss_list).mean().tolist()
    else:
        # No batch size given: evaluate the full arrays in a single forward pass.
        y_pred = self.forward(F.Tensor(x))
        acc, loss = self.loss.metric(y_pred, F.Tensor(y))
    return acc, loss
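# Usage sketch (illustrative, not part of the original code; `model`, `x_test`
# and `y_test` are hypothetical names, not defined in this listing):
#
#     test_acc, test_loss = model.evaluate(x_test, y_test, batch_size=128)
#     print('acc=%.4f  loss=%.4f' % (test_acc, test_loss))
#
# With batch_size=None the whole array is evaluated in one forward pass; with a
# batch size the reported numbers are means over the per-batch metrics.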
def select_batch(dataset: data.Dataset, size: int, generator: torch.Generator = None):
    dataloader = data.DataLoader(dataset, batch_size=size, shuffle=True, generator=generator)
    return next(iter(dataloader))
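# Usage sketch (illustrative, not part of the original code). select_batch()
# shuffles the dataset and returns its first batch, so with a seeded generator
# the same `size` samples are drawn every time, e.g. as a fixed support set.
# `some_dataset` is a hypothetical name, and the DataLoader is assumed to
# accept a torch-style `generator` argument as used above:
#
#     g = torch.Generator().manual_seed(0)
#     support = select_batch(some_dataset, size=5, generator=g)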
def predict(self, x: GLOBAL.np.ndarray, batch_size: int = None):
    self.eval()
    if batch_size is not None:
        assert isinstance(batch_size, int)
        test_dataset = data.DataSet(x)
        test_dataloader = data.DataLoader(test_dataset, batch_size)
        pred_list = []
        for batch_x in test_dataloader:
            y_pred = self.forward(batch_x)
            pred_list.append(y_pred)
        # Stitch the per-batch predictions back into a single tensor.
        pred = F.concat(pred_list, axis=0)
    else:
        # No batch size given: run the whole input through the network at once.
        pred = self.forward(F.Tensor(x))
    return pred
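# Usage sketch (illustrative; `model` and `x_new` are hypothetical names):
#
#     preds = model.predict(x_new, batch_size=256)   # batched inference
#     preds = model.predict(x_new)                   # single full-batch pass
#
# The return value is one Tensor of predictions stacked along axis 0.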
def build_dataloader(self, datasets: T.Dict[str, data.Dataset], n_support: int = None,
                     n_queries: int = None, generator: torch.Generator = None):
    # Fall back to the instance defaults when no overrides are given.
    n_support = n_support or self.n_support
    n_queries = n_queries or self.n_queries
    generator = generator or self.generator
    self.last_datasets = datasets

    dslist = list(datasets.values())

    # Pad shorter class datasets by repetition so every class contributes at
    # least as many samples as the largest one.
    target_len = max([len(ds) for ds in dslist])
    for i, dataset in enumerate(dslist):
        if len(dataset) < target_len:
            new_dataset = dataset + dataset
            while len(new_dataset) < target_len:
                new_dataset += dataset
            dslist[i] = new_dataset

    # Concatenate all classes into one query pool, labelling each sample with
    # the index of the class it came from.
    queries = dslist[0]
    labels = [0] * len(dslist[0])
    for i in range(1, len(dslist)):
        queries += dslist[i]
        labels += [i] * len(dslist[i])

    # Select supports
    # Will be constant for all queries
    supports = []
    for dataset in dslist:
        support = self.select_batch(dataset, size=n_support, generator=generator)
        supports.append(support)

    def collate_fn(batch):
        queries = []
        labels = []
        for row in batch:
            query, label = row
            queries.append(query)
            labels.append(label)
        queries = torch.stack(queries, dim=0)
        labels = torch.tensor(labels, device=queries.device)
        # Every batch carries the same pre-selected support sets.
        return queries, labels, *supports

    dataset = data.dzip(queries, labels)
    # Use the resolved n_queries (not self.n_queries) so an explicit override
    # is honoured.
    return data.DataLoader(dataset, batch_size=n_queries, collate_fn=collate_fn,
                           shuffle=True, generator=generator)
def build_dataloader(datasets: T.Dict[str, data.Dataset], n_support: int, n_queries: int,
                     generator: torch.Generator = None):
    dslist = list(datasets.values())

    queries = dslist[0]
    labels = [0] * len(dslist[0])
    for i in range(1, len(dslist)):
        queries += dslist[i]
        labels += [i] * len(dslist[i])

    # Select supports
    # Will be constant for all queries
    supports = []
    for dataset in dslist:
        support = select_batch(dataset, size=n_support, generator=generator)
        supports.append(support)

    def collate_fn(batch):
        queries = []
        labels = []
        for row in batch:
            query, label = row
            queries.append(query)
            labels.append(label)
        queries = torch.stack(queries, dim=0)
        labels = torch.tensor(labels, device=queries.device)
        return queries, labels, *supports

    dataset = data.dzip(queries, labels)
    return data.DataLoader(dataset, batch_size=n_queries, collate_fn=collate_fn,
                           shuffle=True, generator=generator)
def fewshot_dataloader(datasets: T.List[data.Dataset], n_support: int, n_queries: int,
                       generator: torch.Generator = None, select_batch_fn=select_batch):
    # Pad shorter class datasets by repetition so every class contributes at
    # least as many samples as the largest one.
    target_len = max([len(ds) for ds in datasets])
    for i, dataset in enumerate(datasets):
        if len(dataset) < target_len:
            new_dataset = dataset + dataset
            while len(new_dataset) < target_len:
                new_dataset += dataset
            datasets[i] = new_dataset

    # Concatenate all classes into one query pool, labelling each sample with
    # the index of the class it came from.
    queries = datasets[0]
    labels = [0] * len(datasets[0])
    for i in range(1, len(datasets)):
        queries += datasets[i]
        labels += [i] * len(datasets[i])

    # Select supports
    # Will be constant for all queries
    supports = []
    for dataset in datasets:
        support = select_batch_fn(dataset, size=n_support, generator=generator)
        supports.append(support)

    def collate_fn(batch):
        queries = []
        labels = []
        for row in batch:
            query, label = row
            queries.append(query)
            labels.append(label)
        queries = torch.stack(queries, dim=0)
        labels = torch.tensor(labels, device=queries.device)
        return queries, labels, *supports

    dataset = data.dzip(queries, labels)
    return data.DataLoader(dataset, batch_size=n_queries, collate_fn=collate_fn,
                           shuffle=True, generator=generator)
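# Usage sketch for the few-shot loaders above (illustrative; the per-class
# datasets and the loop body are hypothetical). Each yielded batch contains
# n_queries query samples with integer class labels, followed by one
# n_support-sized support batch per class; the supports are pre-selected once
# and repeated unchanged in every batch:
#
#     loader = fewshot_dataloader([class_a_ds, class_b_ds, class_c_ds],
#                                 n_support=5, n_queries=16,
#                                 generator=torch.Generator().manual_seed(0))
#     for queries, labels, *supports in loader:
#         # queries: (<=16, ...), labels: (<=16,), supports: 3 batches of 5
#         ...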
def fit(self,
        x: GLOBAL.np.ndarray,
        y: GLOBAL.np.ndarray,
        batch_size: int = None,
        epochs: int = 1,
        verbose: int = 1,
        shuffle: bool = True,
        validation_data: Tuple[GLOBAL.np.ndarray] = None,
        validation_split: float = 0.,
        initial_epoch: int = 0,
        ) -> SummaryProfile:
    x = GLOBAL.np.asarray(x)
    y = GLOBAL.np.asarray(y)
    record_data_names = ['train_loss', 'train_acc']

    # If no validation data was supplied, optionally carve it off the tail of x/y.
    if validation_data is None and 0. < validation_split < 1.:
        split = int(x.shape[0] * validation_split)
        valid_x, valid_y = x[-split:], y[-split:]
        train_x, train_y = x[:-split], y[:-split]
        validation_data = (valid_x, valid_y)
    else:
        train_x, train_y = x, y
    # Register the validation curves whenever validation data is available,
    # whether it was passed in or split off above.
    if validation_data is not None:
        record_data_names.append('validation_loss')
        record_data_names.append('validation_acc')

    train_dataset = data.DataSet(train_x, train_y)
    train_dataloader = data.DataLoader(train_dataset, batch_size, shuffle)
    v_acc, v_loss = 0., 0.

    with SummaryProfile(*record_data_names) as profile:
        for epoch in range(initial_epoch, epochs):
            epoch_start_time = time.time()
            if verbose != 0:
                print('\033[1;31m Epoch[%d/%d]\033[0m' % (epoch + 1, epochs))
            progress_bar = ProgressBar(max_iter=len(train_x), verbose=verbose)
            for idx, (batch_x, batch_y) in enumerate(train_dataloader):
                # Move the batch to the active device.
                if GLOBAL.USE_CUDA:
                    batch_x.cuda_()
                    batch_y.cuda_()
                else:
                    batch_x.cpu_()
                    batch_y.cpu_()
                # Re-initialise the graph's output tensor for the first batch and
                # for the last one, which may have a different size.
                if idx == 0 or idx == len(train_dataloader) - 1:
                    self.init_graph_out_tensor(batch_x, batch_y)
                self.train()
                # reset trainable_variables grad
                self.optimizer.zero_grad()
                # forward
                pred = self.forward(batch_x)
                loss = self.loss.forward(pred, batch_y)
                # backward and parameter update
                self.backward(loss)
                self.optimizer.step()
                epoch_time = time.time() - epoch_start_time
                train_acc = self.loss.calc_acc(pred, batch_y)
                profile.step('train_acc', train_acc)
                profile.step('train_loss', loss.item())
                if validation_data is not None:
                    valid_x, valid_y = validation_data
                    v_acc, v_loss = self.evaluate(valid_x, valid_y, batch_size=batch_size)
                    profile.step('validation_loss', v_loss)
                    profile.step('validation_acc', v_acc)
                progress_bar.update(batch_x.shape[0])
                progress_bar.console(verbose, epoch_time, loss.item(), train_acc, v_loss, v_acc)
            if verbose != 0:
                print()
    return profile
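# Usage sketch (illustrative; `model`, `train_x` and `train_y` are hypothetical
# names, and how the model itself is assembled is not shown in this listing):
#
#     profile = model.fit(train_x, train_y,
#                         batch_size=64,
#                         epochs=10,
#                         validation_split=0.1,
#                         verbose=1)
#
# fit() returns the SummaryProfile, so the recorded 'train_loss'/'train_acc'
# curves, plus 'validation_loss'/'validation_acc' when validation data exists,
# remain available after training.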