def print_task_info(lifelong_dataset):
    class_names_samples = {class_: 0 for class_ in lifelong_dataset.cur_task}
    for idx in range(len(lifelong_dataset)):
        labels = lifelong_dataset.get_labels(idx)
        for label in labels:
            if label in class_names_samples.keys():
                class_names_samples[label] += 1
    print_msg(f"Task {lifelong_dataset.cur_task_id} number of samples: {len(lifelong_dataset)}")
    for class_name, num_samples in class_names_samples.items():
        print_msg(f"{class_name} is present in {num_samples} samples")

def _prepare_model_for_new_task(self, task_data: Dataset, dist_args: Optional[dict] = None,
                                **kwargs) -> None:
    """
    A method-specific function that runs before the starting epoch of each new task (called from
    the prepare_model_for_task function). It copies the old network and freezes its gradients. It
    also extends the output layer, imprints weights for the extended nodes, and changes the
    trainable parameters.

    Args:
        task_data (Dataset): The new task dataset
        dist_args (Optional[Dict]): a dictionary of the distributed processing values in case of
            multiple gpus (ex: rank of the device) (default: None)
    """
    self.old_net = copy_freeze(self.net)
    self.old_net.eval()

    cur_task_id = self.cur_task_id
    num_old_classes = int(sum(self.n_cla_per_tsk[:cur_task_id]))
    num_new_classes = self.n_cla_per_tsk[cur_task_id]
    device = next(self.net.parameters()).device

    # Extend last layer
    if cur_task_id > 0:
        output_layer = cosine_linear.SplitCosineLinear(in_features=self.latent_dim,
                                                       out_features1=num_old_classes,
                                                       out_features2=num_new_classes,
                                                       sigma=self.sigma).to(device)
        if cur_task_id == 1:
            output_layer.fc1.weight.data = self.net.model.output_layer.weight.data
        else:
            out_features1 = self.net.model.output_layer.fc1.out_features
            output_layer.fc1.weight.data[:out_features1] = self.net.model.output_layer.fc1.weight.data
            output_layer.fc1.weight.data[out_features1:] = self.net.model.output_layer.fc2.weight.data
        output_layer.sigma.data = self.net.model.output_layer.sigma.data
        self.net.model.output_layer = output_layer
        self.lambda_cur = self.lambda_base * math.sqrt(num_old_classes * 1.0 / num_new_classes)
        print_msg(f"Lambda for less forget is set to {self.lambda_cur}")
    elif cur_task_id != 0:
        raise ValueError("task id cannot be negative")

    # Imprint weights
    with task_data.disable_augmentations():
        if cur_task_id > 0:
            print_msg("Imprinting weights")
            self.net = self._imprint_weights(task_data, self.net, dist_args)

    # Fix parameters of FC1 for less forget and reset optimizer/scheduler
    if cur_task_id > 0:
        trainable_parameters = [param for name, param in self.net.named_parameters()
                                if "output_layer.fc1" not in name]
    else:
        trainable_parameters = self.net.parameters()
    self.reset_optimizer_and_scheduler(trainable_parameters)

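# Illustrative sketch (not part of the original code): how the less-forget distillation
# weight lambda_cur, computed above as lambda_base * sqrt(num_old / num_new), grows as
# classes accumulate across tasks. lambda_base = 5.0 is a hypothetical config value.
import math

lambda_base = 5.0
for num_old_classes, num_new_classes in [(10, 10), (50, 10), (90, 10)]:
    lambda_cur = lambda_base * math.sqrt(num_old_classes * 1.0 / num_new_classes)
    print(f"{num_old_classes} old / {num_new_classes} new classes -> lambda_cur = {lambda_cur:.2f}")
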
def step_scheduler(self, val_metric: Optional = None) -> None:
    """
    Take a step with the scheduler (should be called after each epoch)

    Args:
        val_metric (Optional): a metric to compare in case of reducing the learning rate on
            plateau (default: None)
    """
    cur_lr = self.get_last_lr()
    if self.reduce_lr_on_plateau:
        assert val_metric is not None
        self.scheduler.step(val_metric)
    else:
        self.scheduler.step()
    new_lr = self.get_last_lr()
    if cur_lr != new_lr:
        print_msg(f"learning rate changes to {new_lr}")

def reset_optimizer_and_scheduler(self,
                                  optimizable_parameters: Optional[
                                      Iterator[nn.parameter.Parameter]] = None) -> None:
    """
    Reset the optimizer and scheduler after a task is done (with the option to specify which
    parameters to optimize)

    Args:
        optimizable_parameters (Optional[Iterator[nn.parameter.Parameter]]): specify the parameters
            that should be optimized, in case some parameters need to be frozen (default: None)
    """
    print_msg(f"resetting scheduler and optimizer, learning rate = {self.lr}")
    if optimizable_parameters is None:
        optimizable_parameters = self.net.parameters()
    self.opt, self.scheduler = get_optimizer(
        model_parameters=optimizable_parameters, optimizer_type=self.optimizer_type,
        lr=self.lr, lr_gamma=self.lr_gamma, lr_schedule=self.lr_schedule,
        reduce_lr_on_plateau=self.reduce_lr_on_plateau, weight_decay=self.weight_decay)

def main_worker(gpu, config: dict, dist_args: dict = None):
    distributed = dist_args is not None
    if distributed:
        dist_args["gpu"] = gpu
        device = torch.device(f"cuda:{gpu}")
        dist_args["rank"] = dist_args["node_rank"] * dist_args["ngpus_per_node"] + gpu
        rank = dist_args["rank"]
        print_msg(f"Using GPU {gpu} with rank {rank}")
        dist.init_process_group(backend="nccl", init_method=dist_args["dist_url"],
                                world_size=dist_args["world_size"], rank=dist_args["rank"])
    elif gpu is not None:
        device = torch.device(f"cuda:{gpu}")
        rank = 0
        print_msg(f"Using GPU: {gpu}")
    else:
        device = config["device"]
        rank = 0
        print_msg(f"Using {config['device']}\n")

    checkpoint = None
    non_loadable_attributes = ["logging_path", "dataset_path", "batch_size"]
    temp = {key: val for key, val in config.items() if key in non_loadable_attributes}
    checkpoint_path = os.path.join(config['logging_path'], 'latest_model')
    json_logs_file_name = 'jsonlogs.jsonl'
    if os.path.isfile(checkpoint_path):
        logging.basicConfig(filename=os.path.join(config['logging_path'], "logs.txt"), filemode='a+',
                            format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                            datefmt='%H:%M:%S', level=logging.INFO)
        if distributed:
            print_msg(f"\n\nLoading checkpoint {checkpoint_path} on gpu {dist_args['rank']}")
            checkpoint = torch.load(checkpoint_path, map_location=f"cuda:{dist_args['gpu']}")
        else:
            print_msg(f"\n\nLoading checkpoint {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path)
        config = checkpoint['config']
        for key in non_loadable_attributes:
            config[key] = temp[key]
        if distributed:
            print_msg(f"Loaded the checkpoint successfully on gpu {dist_args['rank']}")
        else:
            print_msg("Loaded the checkpoint successfully")
        if rank == 0:
            print_msg(f"Resuming from task {config['cur_task_id']} epoch {config['task_epoch']}")
            # Remove logs related to training after the checkpoint was saved
            utils.remove_extra_logs(config['cur_task_id'], config['task_epoch'],
                                    os.path.join(config['logging_path'], json_logs_file_name))
            if distributed:
                dist.barrier()
        else:
            dist.barrier()
    else:
        if rank == 0:
            os.makedirs(config['logging_path'], exist_ok=True)
            if os.path.isfile(os.path.join(config['logging_path'], json_logs_file_name)):
                os.remove(os.path.join(config['logging_path'], json_logs_file_name))
            logging.basicConfig(filename=os.path.join(config['logging_path'], "logs.txt"), filemode='w',
                                format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                                datefmt='%H:%M:%S', level=logging.INFO)
            if distributed:
                dist.barrier()
        else:
            dist.barrier()
            logging.basicConfig(filename=os.path.join(config['logging_path'], "logs.txt"), filemode='a',
                                format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                                datefmt='%H:%M:%S', level=logging.INFO)

    torch.random.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    random.seed(config["seed"])
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    if config["wandb_log"]:
        wandb_config = dict(project=config["wandb_project"], config=config, allow_val_change=True,
                            name=config["id"], id=config["id"])
    else:
        wandb_config = None

    logbook_config = ml_logbook.make_config(
        logger_dir=config['logging_path'],
        filename=json_logs_file_name,
        create_multiple_log_files=False,
        wandb_config=wandb_config,
    )
    logbook = ml_logbook.LogBook(config=logbook_config)
    if checkpoint is None and (not distributed or dist_args['rank'] == 0):
        config_to_write = {str(key): str(value) for key, value in config.items()}
        logbook.write_config(config_to_write)

    essential_transforms_fn, augmentation_transforms_fn = get_transforms(config['dataset'])
    lifelong_datasets, tasks, class_names_to_idx = \
        get_lifelong_datasets(config['dataset'], dataset_root=config['dataset_path'],
                              tasks_configuration_id=config["tasks_configuration_id"],
                              essential_transforms_fn=essential_transforms_fn,
                              augmentation_transforms_fn=augmentation_transforms_fn,
                              cache_images=False, joint=config["joint"])

    if config["complete_info"]:
        for lifelong_dataset in lifelong_datasets.values():
            lifelong_dataset.enable_complete_information_mode()

    if checkpoint is None:
        n_cla_per_tsk = [len(task) for task in tasks]
        metadata = {}
        metadata['n_tasks'] = len(tasks)
        metadata["total_num_classes"] = len(class_names_to_idx)
        metadata["tasks"] = tasks
        metadata["class_names_to_idx"] = class_names_to_idx
        metadata["n_cla_per_tsk"] = n_cla_per_tsk
        if rank == 0:
            metadata_to_write = {str(key): str(value) for key, value in metadata.items()}
            logbook.write_metadata(metadata_to_write)
    else:
        metadata = checkpoint['metadata']

    # Assert that methods files lie in the folder "methods"
    method = importlib.import_module('lifelong_methods.methods.' + config["method"])
    model = method.Model(metadata["n_cla_per_tsk"], metadata["class_names_to_idx"], config)

    buffer_dir = None
    map_size = None
    if "imagenet" in config["dataset"]:
        if config['n_memories_per_class'] > 0:
            n_classes = sum(metadata["n_cla_per_tsk"])
            buffer_dir = config['logging_path']
            map_size = int(config['n_memories_per_class'] * n_classes * 1.4e6)
        elif config['total_n_memories'] > 0:
            buffer_dir = config['logging_path']
            map_size = int(config['total_n_memories'] * 1.4e6)
    buffer = method.Buffer(config, buffer_dir, map_size, essential_transforms_fn,
                           augmentation_transforms_fn)

    if gpu is not None:
        torch.cuda.set_device(gpu)
        model.to(device)
        model.net = torch.nn.parallel.DistributedDataParallel(model.net, device_ids=[gpu])
    else:
        model.to(config["device"])

    # If loading a checkpoint, load the corresponding state_dicts
    if checkpoint is not None:
        lifelong_methods.utils.load_model(checkpoint, model, buffer, lifelong_datasets)
        print_msg("Loaded the state dicts successfully")
        starting_task = config["cur_task_id"]
    else:
        starting_task = 0

    for cur_task_id in range(starting_task, len(tasks)):
        if checkpoint is not None and cur_task_id == starting_task and config["task_epoch"] > 0:
            new_task_starting = False
        else:
            new_task_starting = True

        if config["incremental_joint"]:
            for lifelong_dataset in lifelong_datasets.values():
                lifelong_dataset.load_tasks_up_to(cur_task_id)
        else:
            for lifelong_dataset in lifelong_datasets.values():
                lifelong_dataset.choose_task(cur_task_id)
        if rank == 0:
            print_task_info(lifelong_datasets["train"])

        model.net.eval()
        if new_task_starting:
            model.prepare_model_for_new_task(task_data=lifelong_datasets["train"], dist_args=dist_args,
                                             buffer=buffer, num_workers=config["num_workers"])
        model.net.train()

        start_time = time.time()
        task_train(model, buffer, lifelong_datasets, config, metadata, logbook=logbook,
                   dist_args=dist_args)
        end_time = time.time()
        print_msg(f"Time taken on device {rank} for training on task {cur_task_id}: "
                  f"{round((end_time - start_time) / 60, 2)} mins")

        model.net.eval()
        buffer.update_buffer_new_task(lifelong_datasets["train"], dist_args=dist_args, model=model,
                                      batch_size=config["batch_size"])
        if distributed:
            dist.barrier()

        # Assert that all gpus share the same buffer (check the last image)
        if distributed and len(buffer) > 0:
            with buffer.disable_augmentations():
                image_1, _, _ = buffer[-1]
                image_1 = image_1.to(device)
                if rank == 0:
                    image_2, _, _ = buffer[-1]
                    image_2 = image_2.to(device)
                else:
                    image_2 = torch.empty_like(image_1)
                torch.distributed.broadcast(image_2, 0)
                if torch.all(image_1.eq(image_2)):
                    print_msg(f"buffers are similar between rank {dist_args['rank']} and rank 0")

        model.consolidate_task_knowledge(buffer=buffer, device=device,
                                         batch_size=config["batch_size"])

        tasks_eval(model, lifelong_datasets["posttask_valid"], cur_task_id, config, metadata,
                   logbook=logbook, dataset_type="valid", dist_args=dist_args)
        tasks_eval(model, lifelong_datasets["test"], cur_task_id, config, metadata,
                   logbook=logbook, dataset_type="test", dist_args=dist_args)

        config["cur_task_id"] = cur_task_id + 1
        config["task_epoch"] = 0

        if config["save_each_task_model"] and rank == 0:
            save_file = os.path.join(config['logging_path'], f"task_{cur_task_id}_model")
            lifelong_methods.utils.save_model(save_file, config, metadata, model, buffer,
                                              lifelong_datasets)
        if rank == 0:
            print_msg("Saving checkpoint")
            save_file = os.path.join(config['logging_path'], "latest_model")
            lifelong_methods.utils.save_model(save_file, config, metadata, model, buffer,
                                              lifelong_datasets)

parser.add_argument('--num_nodes', type=int, default=1, help="num of nodes to use")
parser.add_argument('--node_rank', default=0, type=int, help='node rank for distributed training')
parser.add_argument('--multiprocessing_distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs.')
parser.add_argument('--dist_url', default="env://", type=str,
                    help='url used to set up distributed training')
args = parser.parse_args()

config = prepare_config(args)
if "iirc" in config["dataset"]:
    config["setup"] = IIRC_SETUP
else:
    config["setup"] = CIL_SETUP
config['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config["ngpus_per_node"] = torch.cuda.device_count()
print_msg(f"number of gpus per node: {config['ngpus_per_node']}")
if args.multiprocessing_distributed or args.num_nodes > 1:
    dist_args = {"num_nodes": args.num_nodes, "node_rank": args.node_rank}
    dist_args["ngpus_per_node"] = config["ngpus_per_node"]
    dist_args["world_size"] = dist_args["ngpus_per_node"] * dist_args["num_nodes"]
    config["batch_size"] = int(ceil(config["batch_size"] / dist_args["world_size"]))
    dist_args["dist_url"] = args.dist_url
    mp.spawn(main_worker, nprocs=dist_args["ngpus_per_node"], args=(config, dist_args))
else:
    main_worker(None, config, None)

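# Illustrative sketch (not part of the original code): the per-process batch size computed
# above divides the configured global batch size across all GPUs in the world, so the
# effective batch size stays roughly constant. The numbers below are hypothetical.
from math import ceil

global_batch_size = 128
ngpus_per_node, num_nodes = 4, 2
world_size = ngpus_per_node * num_nodes
per_process_batch_size = int(ceil(global_batch_size / world_size))
print(per_process_batch_size)  # 16 samples per GPU process
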
def _construct_exemplar_set(self, task_data: Dataset, dist_args: Optional[dict] = None,
                            model: torch.nn.Module = None, batch_size=1, **kwargs) -> None:
    """
    Update the buffer with the new task samples using herding

    Args:
        task_data (Dataset): The new task data
        dist_args (Optional[Dict]): a dictionary of the distributed processing values in case of
            multiple gpus (ex: rank of the device) (default: None)
        model (BaseMethod): The current method object, used to calculate the latent variables
        batch_size (int): The minibatch size
    """
    distributed = dist_args is not None
    if distributed:
        device = torch.device(f"cuda:{dist_args['gpu']}")
        rank = dist_args['rank']
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        rank = 0
    new_class_labels = task_data.cur_task
    model.eval()

    print_msg("Adding buffer samples")
    with task_data.disable_augmentations():  # disable augmentations then enable them (if they were already enabled)
        with torch.no_grad():
            for class_label in new_class_labels:
                class_data_indices = task_data.get_image_indices_by_cla(class_label,
                                                                        self.max_mems_pool_size)
                if distributed:
                    device = torch.device(f"cuda:{dist_args['gpu']}")
                    class_data_indices_to_broadcast = torch.from_numpy(class_data_indices).to(device)
                    torch.distributed.broadcast(class_data_indices_to_broadcast, 0)
                    class_data_indices = class_data_indices_to_broadcast.cpu().numpy()
                sampler = SubsetSampler(class_data_indices)
                class_loader = DataLoader(task_data, batch_size=batch_size, sampler=sampler)
                latent_vectors = []
                for minibatch in class_loader:
                    images = minibatch[0].to(device)
                    output, out_latent = model.forward_net(images)
                    out_latent = out_latent.detach()
                    out_latent = F.normalize(out_latent, p=2, dim=-1)
                    latent_vectors.append(out_latent)
                latent_vectors = torch.cat(latent_vectors, dim=0)
                class_mean = torch.mean(latent_vectors, dim=0)

                chosen_exemplars_ind = []
                exemplars_mean = torch.zeros_like(class_mean)
                while len(chosen_exemplars_ind) < min(self.n_mems_per_cla, len(class_data_indices)):
                    potential_exemplars_mean = \
                        (exemplars_mean.unsqueeze(0) * len(chosen_exemplars_ind) + latent_vectors) \
                        / (len(chosen_exemplars_ind) + 1)
                    distance = (class_mean.unsqueeze(0) - potential_exemplars_mean).norm(dim=-1)
                    shuffled_index = torch.argmin(distance).item()
                    exemplars_mean = potential_exemplars_mean[shuffled_index, :].clone()
                    exemplar_index = class_data_indices[shuffled_index]
                    chosen_exemplars_ind.append(exemplar_index)
                    latent_vectors[shuffled_index, :] = float("inf")

                for image_index in chosen_exemplars_ind:
                    image, label1, label2 = task_data.get_item(image_index)
                    if label2 != NO_LABEL_PLACEHOLDER:
                        warnings.warn(f"Sample is being added to the buffer with labels {label1} and {label2}")
                    self.add_sample(class_label, image, (label1, label2), rank=rank)

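# Minimal herding sketch (not part of the original code), mirroring the greedy selection loop
# above on random features: at each step, pick the candidate whose inclusion keeps the running
# exemplar mean closest to the class mean. Shapes, sample counts, and the number of kept
# exemplars are made up for illustration.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
latent_vectors = F.normalize(torch.randn(20, 8), p=2, dim=-1)  # 20 candidate samples, 8-dim features
class_mean = latent_vectors.mean(dim=0)
chosen_indices = []
exemplars_mean = torch.zeros_like(class_mean)
for _ in range(5):  # keep 5 exemplars
    # mean of the chosen exemplars if each remaining candidate were added next
    candidate_means = (exemplars_mean.unsqueeze(0) * len(chosen_indices) + latent_vectors) \
                      / (len(chosen_indices) + 1)
    idx = torch.argmin((class_mean.unsqueeze(0) - candidate_means).norm(dim=-1)).item()
    exemplars_mean = candidate_means[idx, :].clone()
    chosen_indices.append(idx)
    latent_vectors[idx, :] = float("inf")  # exclude this sample from future iterations
print(chosen_indices)
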
def tasks_eval(model, dataset, cur_task_id, config, metadata, logbook, dataset_type="valid",
               dist_args=None):
    """Log the accuracies of the current model on all the tasks observed so far, using the given
    dataset split."""
    assert dataset.complete_information_mode is True

    distributed = dist_args is not None
    if distributed:
        gpu = dist_args["gpu"]
        rank = dist_args["rank"]
    else:
        gpu = None
        rank = 0

    metrics_dict = {}
    for task_id in range(cur_task_id + 1):
        dataset.choose_task(task_id)
        dataloader = data.DataLoader(dataset, batch_size=config["batch_size"], shuffle=False,
                                     num_workers=config["num_workers"], pin_memory=True)
        _, metrics = evaluate(model, dataloader, config, metadata, test_mode=True, gpu=gpu)
        for metric in metrics.keys():
            metrics_dict[f"task_{task_id}_{dataset_type}_{metric}"] = metrics[metric]
    dataset.load_tasks_up_to(cur_task_id)
    dataloader = data.DataLoader(dataset, batch_size=config["batch_size"], shuffle=False,
                                 num_workers=config["num_workers"], pin_memory=True)
    _, metrics = evaluate(model, dataloader, config, metadata, test_mode=True, gpu=gpu)
    for metric in metrics.keys():
        metrics_dict[f"average_{dataset_type}_{metric}"] = metrics[metric]

    if rank == 0:
        utils.log_task(cur_task_id, metrics_dict, logbook)
        if distributed:
            dist.barrier()
        metrics_dict["rank"] = rank
        print_msg(metrics_dict)
    else:
        dist.barrier()
        metrics_dict["rank"] = rank
        print_msg(metrics_dict)

def task_train(model, buffer, lifelong_datasets, config, metadata, logbook, dist_args=None):
    distributed = dist_args is not None
    if distributed:
        gpu = dist_args["gpu"]
        rank = dist_args["rank"]
    else:
        gpu = None
        rank = 0

    best_checkpoint = {
        "model_state_dict": deepcopy(model.method_state_dict()),
        "best_modified_jaccard": 0
    }
    best_checkpoint_file = os.path.join(config['logging_path'], "best_checkpoint")
    if config['use_best_model']:
        if config['task_epoch'] > 0 and os.path.exists(best_checkpoint_file):
            if distributed:
                best_checkpoint = torch.load(best_checkpoint_file,
                                             map_location=f"cuda:{dist_args['gpu']}")
            else:
                best_checkpoint = torch.load(best_checkpoint_file)

    task_train_data = lifelong_datasets['train']
    if config["method"] == "agem":
        bsm = 0.0
        task_train_data_with_buffer = TaskDataMergedWithBuffer(buffer, task_train_data,
                                                               buffer_sampling_multiplier=bsm)
    else:
        bsm = config["buffer_sampling_multiplier"]
        task_train_data_with_buffer = TaskDataMergedWithBuffer(buffer, task_train_data,
                                                               buffer_sampling_multiplier=bsm)
    task_valid_data = lifelong_datasets['intask_valid']
    cur_task_id = task_train_data.cur_task_id

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            task_train_data_with_buffer, num_replicas=dist_args["world_size"], rank=rank)
    else:
        train_sampler = None
    train_loader = data.DataLoader(task_train_data_with_buffer, batch_size=config["batch_size"],
                                   shuffle=(train_sampler is None), num_workers=config["num_workers"],
                                   pin_memory=True, sampler=train_sampler)
    valid_loader = data.DataLoader(task_valid_data, batch_size=config["batch_size"], shuffle=False,
                                   num_workers=config["num_workers"], pin_memory=True)

    if cur_task_id == 0:
        num_epochs = config['epochs_per_task'] * 2
        print_msg(f"Training for {num_epochs} epochs for the first task (double that of the other tasks)")
    else:
        num_epochs = config['epochs_per_task']
    print_msg(f"Starting training of task {cur_task_id} epoch {config['task_epoch']} till epoch {num_epochs}")

    for epoch in range(config['task_epoch'], num_epochs):
        if distributed:
            train_sampler.set_epoch(epoch)
        start_time = time.time()
        log_dict = {}

        train_loss, train_metrics = epoch_train(model, train_loader, config, metadata, gpu, rank)
        log_dict[f"train_loss_{cur_task_id}"] = train_loss
        for metric in train_metrics.keys():
            log_dict[f"train_{metric}_{cur_task_id}"] = train_metrics[metric]

        valid_loss, valid_metrics = evaluate(model, valid_loader, config, metadata, test_mode=False,
                                             gpu=gpu)
        log_dict[f"valid_loss_{cur_task_id}"] = valid_loss
        for metric in valid_metrics.keys():
            log_dict[f"valid_{metric}_{cur_task_id}"] = valid_metrics[metric]

        model.net.eval()
        model.consolidate_epoch_knowledge(log_dict[f"valid_modified_jaccard_{cur_task_id}"],
                                          task_data=task_train_data, device=config["device"],
                                          batch_size=config["batch_size"])
        # If using the lmdb database, close it and open a new environment to kill active readers
        buffer.reset_lmdb_database()

        if config['use_best_model']:
            if log_dict[f"valid_modified_jaccard_{cur_task_id}"] >= best_checkpoint["best_modified_jaccard"]:
                best_checkpoint["best_modified_jaccard"] = log_dict[f"valid_modified_jaccard_{cur_task_id}"]
                best_checkpoint["model_state_dict"] = deepcopy(model.method_state_dict())
            log_dict[f"best_valid_modified_jaccard_{cur_task_id}"] = best_checkpoint["best_modified_jaccard"]

        if distributed:
            dist.barrier()  # to calculate the time based on the slowest gpu
        end_time = time.time()
        log_dict["elapsed_time"] = round(end_time - start_time, 2)

        if rank == 0:
            utils.log(epoch, cur_task_id, log_dict, logbook)
            if distributed:
                dist.barrier()
            log_dict["rank"] = rank
            print_msg(log_dict)
        else:
            dist.barrier()
            log_dict["rank"] = rank
            print_msg(log_dict)

        # Checkpointing
        config["task_epoch"] = epoch + 1
        if (config["task_epoch"] % config['checkpoint_interval']) == 0 and rank == 0:
            print_msg("Saving latest checkpoint")
            save_file = os.path.join(config['logging_path'], "latest_model")
            lifelong_methods.utils.save_model(save_file, config, metadata, model, buffer,
                                              lifelong_datasets)
            if config['use_best_model']:
                print_msg("Saving best checkpoint")
                torch.save(best_checkpoint, best_checkpoint_file)

    # reset the model parameters to the best performing model
    if config['use_best_model']:
        model.load_method_state_dict(best_checkpoint["model_state_dict"])

def epoch_train(model, dataloader, config, metadata, gpu=None, rank=0):
    train_loss = 0
    train_metrics = {'jaccard_sim': 0., 'modified_jaccard': 0., 'strict_acc': 0., 'recall': 0.}
    data_len = 0
    class_names_to_idx = metadata["class_names_to_idx"]
    num_seen_classes = len(model.seen_classes)
    model.net.train()

    minibatch_i = 0
    for minibatch in dataloader:
        labels_names = list(zip(minibatch[1], minibatch[2]))
        labels = transform_labels_names_to_vector(labels_names, num_seen_classes, class_names_to_idx)

        if gpu is None:
            images = minibatch[0].to(config["device"], non_blocking=True)
            labels = labels.to(config["device"], non_blocking=True)
        else:
            images = minibatch[0].to(torch.device(f"cuda:{gpu}"), non_blocking=True)
            labels = labels.to(torch.device(f"cuda:{gpu}"), non_blocking=True)

        if len(minibatch) > 3:
            if gpu is None:
                in_buffer = minibatch[3].to(config["device"], non_blocking=True)
            else:
                in_buffer = minibatch[3].to(torch.device(f"cuda:{gpu}"), non_blocking=True)
        else:
            in_buffer = None

        predictions, loss = model.observe(images, labels, in_buffer, train=True)
        labels = labels.bool()

        train_loss += loss * images.shape[0]
        train_metrics['jaccard_sim'] += metrics.jaccard_sim(predictions, labels) * images.shape[0]
        train_metrics['modified_jaccard'] += metrics.modified_jaccard_sim(predictions, labels) * images.shape[0]
        train_metrics['strict_acc'] += metrics.strict_accuracy(predictions, labels) * images.shape[0]
        train_metrics['recall'] += metrics.recall(predictions, labels) * images.shape[0]
        data_len += images.shape[0]

        if minibatch_i == 0:
            print_msg(f"rank {rank}, max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024} MB, "
                      f"current memory allocated: {torch.cuda.memory_allocated() / 1024 / 1024} MB\n")
        minibatch_i += 1

    train_loss /= data_len
    train_metrics['jaccard_sim'] /= data_len
    train_metrics['modified_jaccard'] /= data_len
    train_metrics['strict_acc'] /= data_len
    train_metrics['recall'] /= data_len
    return train_loss, train_metrics

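# Illustrative sketch (not part of the original code) of the kind of per-sample Jaccard
# similarity averaged above, assuming metrics.jaccard_sim computes
# |prediction AND target| / |prediction OR target| over the boolean label vectors.
# The tensors below are toy values, not real model outputs.
import torch

predictions = torch.tensor([[True, False, True], [False, True, False]])
labels = torch.tensor([[True, True, False], [False, True, False]])
intersection = (predictions & labels).sum(dim=-1).float()
union = (predictions | labels).sum(dim=-1).float()
print((intersection / union).mean())  # (1/3 + 1/1) / 2 ~= 0.67; epoch_train weights this by batch size
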
def get_lifelong_datasets(dataset_name: str,
                          dataset_root: str = "./data",
                          setup: str = IIRC_SETUP,
                          framework: str = PYTORCH,
                          tasks_configuration_id: int = 0,
                          essential_transforms_fn: Optional[Callable[[Image.Image], Any]] = None,
                          augmentation_transforms_fn: Optional[Callable[[Image.Image], Any]] = None,
                          cache_images: bool = False,
                          joint: bool = False
                          ) -> Tuple[Dict[str, BaseDataset], List[List[str]], Dict[str, int]]:
    """
    Get the incremental refinement learning datasets, as well as the tasks (each of which contains
    the classes introduced in that task), and the index of each class corresponding to its order of
    appearance

    Args:
        dataset_name (str): The name of the dataset, ex: iirc_cifar100
        dataset_root (str): The directory where the dataset is/will be downloaded (default: "./data")
        setup (str): Class Incremental Learning setup (CIL) or Incremental Implicitly Refined
            Classification setup (IIRC) (default: IIRC_SETUP)
        framework (str): The framework to be used, whether PyTorch or Tensorflow. Use Tensorflow for
            any numpy-based dataloading (default: PYTORCH)
        tasks_configuration_id (int): The configuration id, where each configuration corresponds to
            a specific tasks and classes order for each dataset. This id starts from 0 for each
            dataset. Ignored when joint is set to True (default: 0)
        essential_transforms_fn (Optional[Callable[[Image.Image], Any]]): A function that contains
            the essential transforms (for example, converting a pillow image to a tensor) that
            should be applied to each image. This function is applied only when the
            augmentation_transforms_fn is set to None (as in the case of a test set) or inside the
            disable_augmentations context (default: None)
        augmentation_transforms_fn: A function that contains the essential transforms (for example,
            converting a pillow image to a tensor) and the augmentation transforms (for example,
            applying random cropping) that should be applied to each image. When this function is
            provided, essential_transforms_fn is not used except inside the disable_augmentations
            context (default: None)
        cache_images (bool): cache the images that belong to the current task in memory, only
            applicable when using the image path (default: False)
        joint (bool): provide all the classes in a single task for joint training (default: False)

    Returns:
        Tuple[Dict[str, BaseDataset], List[List[str]], Dict[str, int]]:

        lifelong_datasets (Dict[str, BaseDataset]): a dictionary with the keys corresponding to the
        four splits (train, intask_valid, posttask_valid, test), and the values containing the
        dataset object inheriting from BaseDataset for that split.

        tasks (List[List[str]]): a list of lists where each inner list contains the set of classes
        (class names) that will be introduced in that task (example: [[dog, cat, car], [tiger,
        truck, fish]]).

        class_names_to_idx (Dict[str, int]): a dictionary with the class name as key, and the class
        index as value (example: {"dog": 0, "cat": 1, ...}).
""" assert framework in [ PYTORCH, TENSORFLOW ], f'The framework is set to neither "{PYTORCH}" nor "{TENSORFLOW}"' assert setup in [ IIRC_SETUP, CIL_SETUP ], f'The setup is set to neither "{IIRC_SETUP}" nor "{CIL_SETUP}"' assert dataset_name in datasets_names, f'The dataset_name is not in {datasets_names}' print_msg(f"Creating {dataset_name}") datasets, dataset_configuration = \ _get_dataset(dataset_name=dataset_name, dataset_root=dataset_root) tasks, class_names_to_idx = _get_tasks_configuration( dataset_name, tasks_configuration_id=tasks_configuration_id, joint=joint) sprcla_data_pct = dataset_configuration["superclass_data_pct"] subcla_data_pct = dataset_configuration["subclass_data_pct"] using_image_path = dataset_configuration["using_image_path"] sprcla_sampling_size_cap = dataset_configuration[ "superclass_sampling_size_cap"] lifelong_datasets = {} if framework == PYTORCH: from iirc.lifelong_dataset.torch_dataset import Dataset LifeLongDataset = Dataset elif framework == TENSORFLOW: from iirc.lifelong_dataset.tensorflow_dataset import Dataset LifeLongDataset = Dataset else: raise NotImplementedError print_msg(f"Setup used: {setup}\nUsing {framework}") shared_arguments = dict( tasks=tasks, setup=setup, using_image_path=using_image_path, cache_images=cache_images, essential_transforms_fn=essential_transforms_fn, superclass_data_pct=sprcla_data_pct, subclass_data_pct=subcla_data_pct, superclass_sampling_size_cap=sprcla_sampling_size_cap) lifelong_datasets["train"] = LifeLongDataset( dataset=datasets["train"], test_mode=False, augmentation_transforms_fn=augmentation_transforms_fn, **shared_arguments) lifelong_datasets["intask_valid"] = LifeLongDataset( dataset=datasets["intask_valid"], test_mode=False, **shared_arguments) lifelong_datasets["posttask_valid"] = LifeLongDataset( dataset=datasets["posttask_valid"], test_mode=True, **shared_arguments) lifelong_datasets["test"] = LifeLongDataset(dataset=datasets["test"], test_mode=True, **shared_arguments) print_msg("Dataset created") return lifelong_datasets, tasks, class_names_to_idx