Example #1
0
def print_task_info(lifelong_dataset):
    """Print the number of samples in the current task and how many samples each of its classes appears in"""
    class_names_samples = {class_: 0 for class_ in lifelong_dataset.cur_task}
    for idx in range(len(lifelong_dataset)):
        labels = lifelong_dataset.get_labels(idx)
        for label in labels:
            if label in class_names_samples:
                class_names_samples[label] += 1
    print_msg(f"Task {lifelong_dataset.cur_task_id} number of samples: {len(lifelong_dataset)}")
    for class_name, num_samples in class_names_samples.items():
        print_msg(f"{class_name} is present in {num_samples} samples")
Example #2
0
    def _prepare_model_for_new_task(self, task_data: Dataset, dist_args: Optional[dict] = None,
                                    **kwargs) -> None:
        """
        A method specific function that takes place before the starting epoch of each new task (runs from the
        prepare_model_for_task function).
        It copies the old network and freezes it's gradients.
        It also extends the output layer, imprints weights for those extended nodes, and change the trainable parameters

        Args:
            task_data (Dataset): The new task dataset
            dist_args (Optional[Dict]): a dictionary of the distributed processing values in case of multiple gpu (ex:
            rank of the device) (default: None)
        """
        self.old_net = copy_freeze(self.net)
        self.old_net.eval()

        cur_task_id = self.cur_task_id
        num_old_classes = int(sum(self.n_cla_per_tsk[: cur_task_id]))
        num_new_classes = self.n_cla_per_tsk[cur_task_id]
        device = next(self.net.parameters()).device

        # Extend last layer
        if cur_task_id > 0:
            output_layer = cosine_linear.SplitCosineLinear(in_features=self.latent_dim,
                                                           out_features1=num_old_classes,
                                                           out_features2=num_new_classes,
                                                           sigma=self.sigma).to(device)
            if cur_task_id == 1:
                output_layer.fc1.weight.data = self.net.model.output_layer.weight.data
            else:
                out_features1 = self.net.model.output_layer.fc1.out_features
                output_layer.fc1.weight.data[:out_features1] = self.net.model.output_layer.fc1.weight.data
                output_layer.fc1.weight.data[out_features1:] = self.net.model.output_layer.fc2.weight.data
            output_layer.sigma.data = self.net.model.output_layer.sigma.data
            self.net.model.output_layer = output_layer
            self.lambda_cur = self.lambda_base * math.sqrt(num_old_classes * 1.0 / num_new_classes)
            print_msg(f"Lambda for less forget is set to {self.lambda_cur}")
        elif cur_task_id != 0:
            raise ValueError("task id cannot be negative")

        # Imprint weights
        with task_data.disable_augmentations():
            if cur_task_id > 0:
                print_msg("Imprinting weights")
                self.net = self._imprint_weights(task_data, self.net, dist_args)

        # Fix parameters of FC1 for less forget and reset optimizer/scheduler
        if cur_task_id > 0:
            trainable_parameters = [param for name, param in self.net.named_parameters() if
                                    "output_layer.fc1" not in name]
        else:
            trainable_parameters = self.net.parameters()
        self.reset_optimizer_and_scheduler(trainable_parameters)
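As a side note, the adaptive weighting above follows lambda_cur = lambda_base * sqrt(num_old_classes / num_new_classes); a small standalone sketch (with illustrative numbers, not taken from any configuration) shows how the weight grows as old classes accumulate:

import math

lambda_base = 5.0                                # illustrative value only
for num_old, num_new in [(10, 10), (50, 10), (90, 10)]:
    lambda_cur = lambda_base * math.sqrt(num_old / num_new)
    print(f"{num_old} old vs {num_new} new classes -> lambda = {lambda_cur:.2f}")
# prints 5.00, 11.18, 15.00: the "less forget" term is weighted more heavily as old classes accumulate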
Example #3
0
    def step_scheduler(self, val_metric: Optional = None) -> None:
        """
        Take a step with the scheduler (should be called after each epoch)

        Args:
            val_metric (Optional): a metric to compare in case of reducing the learning rate on plateau (default: None)
        """
        cur_lr = self.get_last_lr()
        if self.reduce_lr_on_plateau:
            assert val_metric is not None, "val_metric is required when reduce_lr_on_plateau is used"
            self.scheduler.step(val_metric)
        else:
            self.scheduler.step()
        new_lr = self.get_last_lr()
        if cur_lr != new_lr:
            print_msg(f"learning rate changes to {new_lr}")
Example #4
0
    def reset_optimizer_and_scheduler(
        self,
        optimizable_parameters: Optional[Iterator[
            nn.parameter.Parameter]] = None
    ) -> None:
        """
        Reset the optimizer and scheduler after a task is done (with the option to specify which parameters to optimize

        Args:
            optimizable_parameters (Optional[Iterator[nn.parameter.Parameter]]: specify the parameters that should be
                optimized, in case some parameters needs to be frozen (default: None)
        """
        print_msg(
            f"resetting scheduler and optimizer, learning rate = {self.lr}")
        if optimizable_parameters is None:
            optimizable_parameters = self.net.parameters()
        self.opt, self.scheduler = get_optimizer(
            model_parameters=optimizable_parameters,
            optimizer_type=self.optimizer_type,
            lr=self.lr,
            lr_gamma=self.lr_gamma,
            lr_schedule=self.lr_schedule,
            reduce_lr_on_plateau=self.reduce_lr_on_plateau,
            weight_decay=self.weight_decay)
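A standalone sketch of passing a filtered parameter iterator to an optimizer, mirroring how reset_optimizer_and_scheduler is called in Example #2 with output_layer.fc1 excluded (toy two-layer model; the names are illustrative):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
# keep every parameter except those of the last layer (index "1" in the Sequential)
trainable_parameters = [param for name, param in net.named_parameters() if not name.startswith("1.")]
opt = torch.optim.SGD(trainable_parameters, lr=0.01, momentum=0.9, weight_decay=1e-4)
n_optimized = sum(p.numel() for group in opt.param_groups for p in group["params"])
print(n_optimized)   # 72 = 8*8 + 8, i.e. only the first layer is optimized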
Example #5
0
def main_worker(gpu, config: dict, dist_args: dict = None):
    distributed = dist_args is not None
    if distributed:
        dist_args["gpu"] = gpu
        device = torch.device(f"cuda:{gpu}")
        dist_args["rank"] = dist_args["node_rank"] * dist_args["ngpus_per_node"] + gpu
        rank = dist_args["rank"]
        print_msg(f"Using GPU {gpu} with rank {rank}")
        dist.init_process_group(backend="nccl", init_method=dist_args["dist_url"],
                                world_size=dist_args["world_size"], rank=dist_args["rank"])
    elif gpu is not None:
        device = torch.device(f"cuda:{gpu}")
        rank = 0
        print_msg(f"Using GPU: {gpu}")
    else:
        device = config["device"]
        rank = 0
        print_msg(f"using {config['device']}\n")

    checkpoint = None
    non_loadable_attributes = ["logging_path", "dataset_path", "batch_size"]
    temp = {key: val for key, val in config.items() if key in non_loadable_attributes}
    checkpoint_path = os.path.join(config['logging_path'], 'latest_model')
    json_logs_file_name = 'jsonlogs.jsonl'
    if os.path.isfile(checkpoint_path):
        logging.basicConfig(filename=os.path.join(config['logging_path'], "logs.txt"),
                            filemode='a+',
                            format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                            datefmt='%H:%M:%S',
                            level=logging.INFO)
        if distributed:
            print_msg(f"\n\nLoading checkpoint {checkpoint_path} on gpu {dist_args['rank']}")
            checkpoint = torch.load(checkpoint_path, map_location=f"cuda:{dist_args['gpu']}")
        else:
            print_msg(f"\n\nLoading checkpoint {checkpoint_path}")
            checkpoint = torch.load(checkpoint_path)
        config = checkpoint['config']
        for key in non_loadable_attributes:
            config[key] = temp[key]
        if distributed:
            print_msg(f"Loaded the checkpoint successfully on gpu {dist_args['rank']}")
        else:
            print_msg(f"Loaded the checkpoint successfully")

        if rank == 0:
            print_msg(f"Resuming from task {config['cur_task_id']} epoch {config['task_epoch']}")
            # Remove logs related to training after the checkpoint was saved
            utils.remove_extra_logs(config['cur_task_id'], config['task_epoch'],
                                    os.path.join(config['logging_path'], json_logs_file_name))
            if distributed:
                dist.barrier()
        else:
            dist.barrier()
    else:
        if rank == 0:
            os.makedirs(config['logging_path'], exist_ok=True)
            if os.path.isfile(os.path.join(config['logging_path'], json_logs_file_name)):
                os.remove(os.path.join(config['logging_path'], json_logs_file_name))
            logging.basicConfig(filename=os.path.join(config['logging_path'], "logs.txt"),
                                filemode='w',
                                format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                                datefmt='%H:%M:%S',
                                level=logging.INFO)
            if distributed:
                dist.barrier()
        else:
            dist.barrier()
            logging.basicConfig(filename=os.path.join(config['logging_path'], "logs.txt"),
                                filemode='a',
                                format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                                datefmt='%H:%M:%S',
                                level=logging.INFO)

    torch.random.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    random.seed(config["seed"])

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    if config["wandb_log"]:
        wandb_config = dict(project=config["wandb_project"], config=config, allow_val_change=True,
                            name=config["id"], id=config["id"])
    else:
        wandb_config = None

    logbook_config = ml_logbook.make_config(
        logger_dir=config['logging_path'],
        filename=json_logs_file_name,
        create_multiple_log_files=False,
        wandb_config=wandb_config,
    )
    logbook = ml_logbook.LogBook(config=logbook_config)
    if checkpoint is None and (not distributed or dist_args['rank'] == 0):
        config_to_write = {str(key): str(value) for key, value in config.items()}
        logbook.write_config(config_to_write)

    essential_transforms_fn, augmentation_transforms_fn = get_transforms(config['dataset'])
    lifelong_datasets, tasks, class_names_to_idx = \
        get_lifelong_datasets(config['dataset'], dataset_root=config['dataset_path'],
                              tasks_configuration_id=config["tasks_configuration_id"],
                              essential_transforms_fn=essential_transforms_fn,
                              augmentation_transforms_fn=augmentation_transforms_fn, cache_images=False,
                              joint=config["joint"])

    if config["complete_info"]:
        for lifelong_dataset in lifelong_datasets.values():
            lifelong_dataset.enable_complete_information_mode()

    if checkpoint is None:
        n_cla_per_tsk = [len(task) for task in tasks]
        metadata = {}
        metadata['n_tasks'] = len(tasks)
        metadata["total_num_classes"] = len(class_names_to_idx)
        metadata["tasks"] = tasks
        metadata["class_names_to_idx"] = class_names_to_idx
        metadata["n_cla_per_tsk"] = n_cla_per_tsk
        if rank == 0:
            metadata_to_write = {str(key): str(value) for key, value in metadata.items()}
            logbook.write_metadata(metadata_to_write)
    else:
        metadata = checkpoint['metadata']

    # Method modules are expected to lie in the folder "methods" for this import to work
    method = importlib.import_module('lifelong_methods.methods.' + config["method"])
    model = method.Model(metadata["n_cla_per_tsk"], metadata["class_names_to_idx"], config)

    buffer_dir = None
    map_size = None
    if "imagenet" in config["dataset"]:
        if config['n_memories_per_class'] > 0:
            n_classes = sum(metadata["n_cla_per_tsk"])
            buffer_dir = config['logging_path']
            map_size = int(config['n_memories_per_class'] * n_classes * 1.4e6)
        elif config['total_n_memories'] > 0:
            buffer_dir = config['logging_path']
            map_size = int(config['total_n_memories'] * 1.4e6)
    buffer = method.Buffer(config, buffer_dir, map_size, essential_transforms_fn, augmentation_transforms_fn)

    if gpu is not None:
        torch.cuda.set_device(gpu)
        model.to(device)
        if distributed:
            # DistributedDataParallel requires the process group initialized above
            model.net = torch.nn.parallel.DistributedDataParallel(model.net, device_ids=[gpu])
    else:
        model.to(config["device"])

    # If loading a checkpoint, load the corresponding state_dicts
    if checkpoint is not None:
        lifelong_methods.utils.load_model(checkpoint, model, buffer, lifelong_datasets)
        print_msg(f"Loaded the state dicts successfully")
        starting_task = config["cur_task_id"]
    else:
        starting_task = 0

    for cur_task_id in range(starting_task, len(tasks)):
        if checkpoint is not None and cur_task_id == starting_task and config["task_epoch"] > 0:
            new_task_starting = False
        else:
            new_task_starting = True

        if config["incremental_joint"]:
            for lifelong_dataset in lifelong_datasets.values():
                lifelong_dataset.load_tasks_up_to(cur_task_id)
        else:
            for lifelong_dataset in lifelong_datasets.values():
                lifelong_dataset.choose_task(cur_task_id)

        if rank == 0:
            print_task_info(lifelong_datasets["train"])

        model.net.eval()
        if new_task_starting:
            model.prepare_model_for_new_task(task_data=lifelong_datasets["train"], dist_args=dist_args, buffer=buffer,
                                             num_workers=config["num_workers"])

        model.net.train()
        start_time = time.time()
        task_train(model, buffer, lifelong_datasets, config, metadata, logbook=logbook, dist_args=dist_args)
        end_time = time.time()
        print_msg(f"Time taken on device {rank} for training on task {cur_task_id}: "
                  f"{round((end_time - start_time) / 60, 2)} mins")

        model.net.eval()
        buffer.update_buffer_new_task(
            lifelong_datasets["train"], dist_args=dist_args, model=model, batch_size=config["batch_size"]
        )

        if distributed:
            dist.barrier()
        # Sanity check that all gpus share the same buffer (compare the last image across ranks)
        if distributed and len(buffer) > 0:
            with buffer.disable_augmentations():
                image_1, _, _ = buffer[-1]
                image_1 = image_1.to(device)
                if rank == 0:
                    image_2, _, _ = buffer[-1]
                    image_2 = image_2.to(device)
                else:
                    image_2 = torch.empty_like(image_1)
                torch.distributed.broadcast(image_2, 0)
                if torch.all(image_1.eq(image_2)):
                    print_msg(f"buffers are similar between rank {dist_args['rank']} and rank 0")
                else:
                    print_msg(f"WARNING: buffer mismatch between rank {dist_args['rank']} and rank 0")

        model.consolidate_task_knowledge(
            buffer=buffer, device=device, batch_size=config["batch_size"]
        )

        tasks_eval(
            model, lifelong_datasets["posttask_valid"], cur_task_id, config, metadata, logbook=logbook,
            dataset_type="valid", dist_args=dist_args
        )
        tasks_eval(
            model, lifelong_datasets["test"], cur_task_id, config, metadata, logbook=logbook, dataset_type="test",
            dist_args=dist_args
        )

        config["cur_task_id"] = cur_task_id + 1
        config["task_epoch"] = 0

        if config["save_each_task_model"] and rank == 0:
            save_file = os.path.join(config['logging_path'], f"task_{cur_task_id}_model")
            lifelong_methods.utils.save_model(save_file, config, metadata, model, buffer, lifelong_datasets)
        if rank == 0:
            print_msg("Saving checkpoint")
            save_file = os.path.join(config['logging_path'], "latest_model")
            lifelong_methods.utils.save_model(save_file, config, metadata, model, buffer, lifelong_datasets)
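A small arithmetic sketch of the lmdb map_size budget computed above, which reserves roughly 1.4 MB per stored image (the numbers below are illustrative, not taken from any configuration):

n_memories_per_class = 20                                   # illustrative
n_classes = 100                                             # illustrative
map_size = int(n_memories_per_class * n_classes * 1.4e6)    # bytes, ~1.4 MB per buffered image
print(f"{map_size / 1024 ** 3:.2f} GiB reserved for the replay buffer")   # ~2.61 GiB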
Example #6
0
    parser.add_argument('--num_nodes', type=int, default=1,
                        help="num of nodes to use")
    parser.add_argument('--node_rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--multiprocessing_distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs.')
    parser.add_argument('--dist_url', default="env://", type=str,
                        help='url used to set up distributed training')

    args = parser.parse_args()

    config = prepare_config(args)
    if "iirc" in config["dataset"]:
        config["setup"] = IIRC_SETUP
    else:
        config["setup"] = CIL_SETUP

    config['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config["ngpus_per_node"] = torch.cuda.device_count()
    print_msg(f"number of gpus per node: {config['ngpus_per_node']}")
    if args.multiprocessing_distributed or args.num_nodes > 1:
        dist_args = {"num_nodes": args.num_nodes, "node_rank": args.node_rank}
        dist_args["ngpus_per_node"] = config["ngpus_per_node"]
        dist_args["world_size"] = dist_args["ngpus_per_node"] * dist_args["num_nodes"]
        config["batch_size"] = int(ceil(config["batch_size"] / dist_args["world_size"]))
        dist_args["dist_url"] = args.dist_url
        mp.spawn(main_worker, nprocs=dist_args["ngpus_per_node"], args=(config, dist_args))
    else:
        main_worker(None, config, None)
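A standalone sketch of the mp.spawn calling convention relied on above: the spawned function receives the process index (used here as the local GPU id) as its first argument, followed by the entries of args (the worker below is a stand-in, not main_worker itself):

import torch.multiprocessing as mp

def demo_worker(local_rank, config, dist_args):
    # local_rank plays the role of `gpu` in main_worker above
    print(f"process {local_rank} of {dist_args['world_size']} started with id {config['id']}")

if __name__ == "__main__":
    mp.spawn(demo_worker, nprocs=2, args=({"id": "demo"}, {"world_size": 2}))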
Example #7
0
    def _construct_exemplar_set(self, task_data: Dataset, dist_args: Optional[dict] = None,
                                model: torch.nn.Module = None, batch_size=1, **kwargs) -> None:
        """
        Update the buffer with the new task samples using herding

        Args:
            task_data (Dataset): The new task data
            dist_args (Optional[Dict]): a dictionary of the distributed processing values in case of multiple gpu (ex:
            rank of the device) (default: None)
            model (BaseMethod): The current method object to calculate the latent variables
            batch_size (int): The minibatch size
        """
        distributed = dist_args is not None
        if distributed:
            device = torch.device(f"cuda:{dist_args['gpu']}")
            rank = dist_args['rank']
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            rank = 0
        new_class_labels = task_data.cur_task
        model.eval()

        print_msg(f"Adding buffer samples")  ####
        with task_data.disable_augmentations():  # disable augmentations then enable them (if they were already enabled)
            with torch.no_grad():
                for class_label in new_class_labels:
                    class_data_indices = task_data.get_image_indices_by_cla(class_label, self.max_mems_pool_size)
                    if distributed:
                        class_data_indices_to_broadcast = torch.from_numpy(class_data_indices).to(device)
                        torch.distributed.broadcast(class_data_indices_to_broadcast, 0)
                        class_data_indices = class_data_indices_to_broadcast.cpu().numpy()
                    sampler = SubsetSampler(class_data_indices)
                    class_loader = DataLoader(task_data, batch_size=batch_size, sampler=sampler)
                    latent_vectors = []
                    for minibatch in class_loader:
                        images = minibatch[0].to(device)
                        output, out_latent = model.forward_net(images)
                        out_latent = out_latent.detach()
                        out_latent = F.normalize(out_latent, p=2, dim=-1)
                        latent_vectors.append(out_latent)
                    latent_vectors = torch.cat(latent_vectors, dim=0)
                    class_mean = torch.mean(latent_vectors, dim=0)

                    chosen_exemplars_ind = []
                    exemplars_mean = torch.zeros_like(class_mean)
                    while len(chosen_exemplars_ind) < min(self.n_mems_per_cla, len(class_data_indices)):
                        # For each remaining sample, compute what the exemplars mean would become if it were added
                        potential_exemplars_mean = (exemplars_mean.unsqueeze(0) * len(
                            chosen_exemplars_ind) + latent_vectors) \
                                                   / (len(chosen_exemplars_ind) + 1)
                        # Choose the sample that keeps the exemplars mean closest to the class mean
                        distance = (class_mean.unsqueeze(0) - potential_exemplars_mean).norm(dim=-1)
                        shuffled_index = torch.argmin(distance).item()
                        exemplars_mean = potential_exemplars_mean[shuffled_index, :].clone()
                        exemplar_index = class_data_indices[shuffled_index]
                        chosen_exemplars_ind.append(exemplar_index)
                        # Exclude the chosen sample from the following herding iterations
                        latent_vectors[shuffled_index, :] = float("inf")

                    for image_index in chosen_exemplars_ind:
                        image, label1, label2 = task_data.get_item(image_index)
                        if label2 != NO_LABEL_PLACEHOLDER:
                            warnings.warn(f"Sample is being added to the buffer with labels {label1} and {label2}")
                        self.add_sample(class_label, image, (label1, label2), rank=rank)
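A standalone sketch of the herding selection above on random features: at each step the sample whose inclusion keeps the running exemplar mean closest to the class mean is picked, then excluded from later rounds (feature dimensions and counts are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
latent_vectors = F.normalize(torch.randn(50, 8), p=2, dim=-1)   # stand-in for the class features
class_mean = latent_vectors.mean(dim=0)
n_exemplars = 5

chosen = []
exemplars_mean = torch.zeros_like(class_mean)
for _ in range(n_exemplars):
    # mean that would result from adding each candidate to the already chosen exemplars
    candidate_means = (exemplars_mean * len(chosen) + latent_vectors) / (len(chosen) + 1)
    distances = (class_mean - candidate_means).norm(dim=-1)
    index = torch.argmin(distances).item()
    exemplars_mean = candidate_means[index].clone()
    chosen.append(index)
    latent_vectors[index] = float("inf")         # exclude this sample from later iterations
print(chosen)                                    # indices of the selected exemplars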
Example #8
0
def tasks_eval(model,
               dataset,
               cur_task_id,
               config,
               metadata,
               logbook,
               dataset_type="valid",
               dist_args=None):
    """log the accuracies of the new model on all observed tasks
    :param metadata:
    """
    assert dataset.complete_information_mode is True

    distributed = dist_args is not None
    if distributed:
        gpu = dist_args["gpu"]
        rank = dist_args["rank"]
    else:
        gpu = None
        rank = 0

    metrics_dict = {}
    for task_id in range(cur_task_id + 1):
        dataset.choose_task(task_id)
        dataloader = data.DataLoader(dataset,
                                     batch_size=config["batch_size"],
                                     shuffle=False,
                                     num_workers=config["num_workers"],
                                     pin_memory=True)
        _, metrics = evaluate(model,
                              dataloader,
                              config,
                              metadata,
                              test_mode=True,
                              gpu=gpu)
        for metric in metrics.keys():
            metrics_dict[f"task_{task_id}_{dataset_type}_{metric}"] = metrics[
                metric]
    dataset.load_tasks_up_to(cur_task_id)
    dataloader = data.DataLoader(dataset,
                                 batch_size=config["batch_size"],
                                 shuffle=False,
                                 num_workers=config["num_workers"],
                                 pin_memory=True)
    _, metrics = evaluate(model,
                          dataloader,
                          config,
                          metadata,
                          test_mode=True,
                          gpu=gpu)
    for metric in metrics.keys():
        metrics_dict[f"average_{dataset_type}_{metric}"] = metrics[metric]

    if rank == 0:
        utils.log_task(cur_task_id, metrics_dict, logbook)
        if distributed:
            dist.barrier()
            metrics_dict["rank"] = rank
            print_msg(metrics_dict)
    else:
        dist.barrier()
        metrics_dict["rank"] = rank
        print_msg(metrics_dict)
Example #9
0
def task_train(model,
               buffer,
               lifelong_datasets,
               config,
               metadata,
               logbook,
               dist_args=None):
    distributed = dist_args is not None
    if distributed:
        gpu = dist_args["gpu"]
        rank = dist_args["rank"]
    else:
        gpu = None
        rank = 0

    best_checkpoint = {
        "model_state_dict": deepcopy(model.method_state_dict()),
        "best_modified_jaccard": 0
    }
    best_checkpoint_file = os.path.join(config['logging_path'],
                                        "best_checkpoint")
    if config['use_best_model']:
        if config['task_epoch'] > 0 and os.path.exists(best_checkpoint_file):
            if distributed:
                best_checkpoint = torch.load(
                    best_checkpoint_file,
                    map_location=f"cuda:{dist_args['gpu']}")
            else:
                best_checkpoint = torch.load(best_checkpoint_file)

    task_train_data = lifelong_datasets['train']
    if config["method"] == "agem":
        bsm = 0.0
        task_train_data_with_buffer = TaskDataMergedWithBuffer(
            buffer, task_train_data, buffer_sampling_multiplier=bsm)
    else:
        bsm = config["buffer_sampling_multiplier"]
        task_train_data_with_buffer = TaskDataMergedWithBuffer(
            buffer, task_train_data, buffer_sampling_multiplier=bsm)
    task_valid_data = lifelong_datasets['intask_valid']
    cur_task_id = task_train_data.cur_task_id

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            task_train_data_with_buffer,
            num_replicas=dist_args["world_size"],
            rank=rank)
    else:
        train_sampler = None

    train_loader = data.DataLoader(task_train_data_with_buffer,
                                   batch_size=config["batch_size"],
                                   shuffle=(train_sampler is None),
                                   num_workers=config["num_workers"],
                                   pin_memory=True,
                                   sampler=train_sampler)
    valid_loader = data.DataLoader(task_valid_data,
                                   batch_size=config["batch_size"],
                                   shuffle=False,
                                   num_workers=config["num_workers"],
                                   pin_memory=True)

    if cur_task_id == 0:
        num_epochs = config['epochs_per_task'] * 2
        print_msg(
            f"Training for {num_epochs} epochs for the first task (double that of the other tasks)"
        )
    else:
        num_epochs = config['epochs_per_task']
    print_msg(
        f"Starting training of task {cur_task_id} epoch {config['task_epoch']} till epoch {num_epochs}"
    )
    for epoch in range(config['task_epoch'], num_epochs):
        if distributed:
            train_sampler.set_epoch(epoch)
        start_time = time.time()
        log_dict = {}
        train_loss, train_metrics = epoch_train(model, train_loader, config,
                                                metadata, gpu, rank)
        log_dict[f"train_loss_{cur_task_id}"] = train_loss
        for metric in train_metrics.keys():
            log_dict[f"train_{metric}_{cur_task_id}"] = train_metrics[metric]

        valid_loss, valid_metrics = evaluate(model,
                                             valid_loader,
                                             config,
                                             metadata,
                                             test_mode=False,
                                             gpu=gpu)
        log_dict[f"valid_loss_{cur_task_id}"] = valid_loss
        for metric in valid_metrics.keys():
            log_dict[f"valid_{metric}_{cur_task_id}"] = valid_metrics[metric]

        model.net.eval()
        model.consolidate_epoch_knowledge(
            log_dict[f"valid_modified_jaccard_{cur_task_id}"],
            task_data=task_train_data,
            device=config["device"],
            batch_size=config["batch_size"])
        # If using the lmdb database, close it and open a new environment to kill active readers
        buffer.reset_lmdb_database()

        if config['use_best_model']:
            if log_dict[
                    f"valid_modified_jaccard_{cur_task_id}"] >= best_checkpoint[
                        "best_modified_jaccard"]:
                best_checkpoint["best_modified_jaccard"] = log_dict[
                    f"valid_modified_jaccard_{cur_task_id}"]
                best_checkpoint["model_state_dict"] = deepcopy(
                    model.method_state_dict())
            log_dict[
                f"best_valid_modified_jaccard_{cur_task_id}"] = best_checkpoint[
                    "best_modified_jaccard"]

        if distributed:
            dist.barrier()  # so that the elapsed time is measured based on the slowest gpu
        end_time = time.time()
        log_dict["elapsed_time"] = round(end_time - start_time, 2)

        if rank == 0:
            utils.log(epoch, cur_task_id, log_dict, logbook)
            if distributed:
                dist.barrier()
                log_dict["rank"] = rank
                print_msg(log_dict)
        else:
            dist.barrier()
            log_dict["rank"] = rank
            print_msg(log_dict)

        # Checkpointing
        config["task_epoch"] = epoch + 1
        if (config["task_epoch"] %
                config['checkpoint_interval']) == 0 and rank == 0:
            print_msg("Saving latest checkpoint")
            save_file = os.path.join(config['logging_path'], "latest_model")
            lifelong_methods.utils.save_model(save_file, config, metadata,
                                              model, buffer, lifelong_datasets)
            if config['use_best_model']:
                print_msg("Saving best checkpoint")
                torch.save(best_checkpoint, best_checkpoint_file)

    # reset the model parameters to the best performing model
    if config['use_best_model']:
        model.load_method_state_dict(best_checkpoint["model_state_dict"])
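A standalone sketch of the best-model bookkeeping used above: the state dict with the highest validation score is kept (as a deep copy) and restored once the task finishes (toy model and metric values):

from copy import deepcopy

import torch

net = torch.nn.Linear(4, 2)
best_checkpoint = {"model_state_dict": deepcopy(net.state_dict()), "best_modified_jaccard": 0.0}
for val_modified_jaccard in [0.3, 0.5, 0.4]:     # illustrative per-epoch validation scores
    if val_modified_jaccard >= best_checkpoint["best_modified_jaccard"]:
        best_checkpoint["best_modified_jaccard"] = val_modified_jaccard
        best_checkpoint["model_state_dict"] = deepcopy(net.state_dict())
net.load_state_dict(best_checkpoint["model_state_dict"])
print(best_checkpoint["best_modified_jaccard"])  # 0.5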
Example #10
0
def epoch_train(model, dataloader, config, metadata, gpu=None, rank=0):
    """Train the model for one epoch over the given dataloader and return the average training loss and metrics"""
    train_loss = 0
    train_metrics = {
        'jaccard_sim': 0.,
        'modified_jaccard': 0.,
        'strict_acc': 0.,
        'recall': 0.
    }
    data_len = 0
    class_names_to_idx = metadata["class_names_to_idx"]
    num_seen_classes = len(model.seen_classes)
    model.net.train()

    minibatch_i = 0
    for minibatch in dataloader:
        labels_names = list(zip(minibatch[1], minibatch[2]))
        labels = transform_labels_names_to_vector(labels_names,
                                                  num_seen_classes,
                                                  class_names_to_idx)

        if gpu is None:
            images = minibatch[0].to(config["device"], non_blocking=True)
            labels = labels.to(config["device"], non_blocking=True)
        else:
            images = minibatch[0].to(torch.device(f"cuda:{gpu}"),
                                     non_blocking=True)
            labels = labels.to(torch.device(f"cuda:{gpu}"), non_blocking=True)

        if len(minibatch) > 3:
            if gpu is None:
                in_buffer = minibatch[3].to(config["device"],
                                            non_blocking=True)
            else:
                in_buffer = minibatch[3].to(torch.device(f"cuda:{gpu}"),
                                            non_blocking=True)
        else:
            in_buffer = None

        predictions, loss = model.observe(images,
                                          labels,
                                          in_buffer,
                                          train=True)
        labels = labels.bool()
        train_loss += loss * images.shape[0]
        train_metrics['jaccard_sim'] += metrics.jaccard_sim(
            predictions, labels) * images.shape[0]
        train_metrics['modified_jaccard'] += metrics.modified_jaccard_sim(
            predictions, labels) * images.shape[0]
        train_metrics['strict_acc'] += metrics.strict_accuracy(
            predictions, labels) * images.shape[0]
        train_metrics['recall'] += metrics.recall(predictions,
                                                  labels) * images.shape[0]
        data_len += images.shape[0]

        if minibatch_i == 0:
            print_msg(
                f"rank {rank}, max memory allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024} MB, "
                f"current memory allocated: {torch.cuda.memory_allocated() / 1024 / 1024} MB\n"
            )
        minibatch_i += 1

    train_loss /= data_len
    train_metrics['jaccard_sim'] /= data_len
    train_metrics['modified_jaccard'] /= data_len
    train_metrics['strict_acc'] /= data_len
    train_metrics['recall'] /= data_len
    return train_loss, train_metrics
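The metrics module itself is not shown here; the following is only a minimal sketch of what a batch Jaccard similarity on boolean multi-label tensors could look like, to illustrate the kind of quantity being averaged above (the repository's actual implementation may differ):

import torch

def jaccard_sim_sketch(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    # per-sample intersection over union of the predicted and true label sets, averaged over the batch
    intersection = (predictions & labels).sum(dim=-1).float()
    union = (predictions | labels).sum(dim=-1).float().clamp(min=1)
    return (intersection / union).mean().item()

preds = torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.bool)
labels = torch.tensor([[1, 1, 1], [0, 1, 0]], dtype=torch.bool)
print(jaccard_sim_sketch(preds, labels))         # (2/3 + 1/1) / 2 ≈ 0.83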
Example #11
0
def get_lifelong_datasets(
    dataset_name: str,
    dataset_root: str = "./data",
    setup: str = IIRC_SETUP,
    framework: str = PYTORCH,
    tasks_configuration_id: int = 0,
    essential_transforms_fn: Optional[Callable[[Image.Image], Any]] = None,
    augmentation_transforms_fn: Optional[Callable[[Image.Image], Any]] = None,
    cache_images: bool = False,
    joint: bool = False
) -> Tuple[Dict[str, BaseDataset], List[List[str]], Dict[str, int]]:
    """
    Get the incremental refinement learning datasets, as well as the tasks (which contain the classes introduced at
    each task), and the index for each class corresponding to its order of appearance

    Args:
        dataset_name (str): The name of the dataset, ex: iirc_cifar100
        dataset_root (str): The directory where the dataset is/will be downloaded (default: "./data")
        setup (str): Class Incremental Learning setup (CIL) or Incremental Implicitly Refined Classification setup
            (IIRC) (default: IIRC_SETUP)
        framework (str): The framework to be used, whether PyTorch or Tensorflow. Use Tensorflow for any numpy-based
            dataloading (default: PYTORCH)
        tasks_configuration_id (int): The configuration id, where each configuration corresponds to a specific ordering
            of the tasks and classes for each dataset. This id starts from 0 for each dataset. Ignored when joint is set
            to True (default: 0)
        essential_transforms_fn (Optional[Callable[[Image.Image], Any]]): A function that contains the essential
            transforms (for example, converting a pillow image to a tensor) that should be applied to each image. This
            function is applied only when the augmentation_transforms_fn is set to None (as in the case of a test set)
            or inside the disable_augmentations context (default: None)
        augmentation_transforms_fn (Optional[Callable[[Image.Image], Any]]): A function that contains both the essential
            transforms (for example, converting a pillow image to a tensor) and the augmentation transforms (for
            example, applying random cropping) that should be applied to each image. When this function is provided,
            essential_transforms_fn is not used except inside the disable_augmentations context (default: None)
        cache_images (bool): cache the images that belong to the current task in memory, only applicable when using the
            image path (default: False)
        joint (bool): provide all the classes in a single task for joint training (default: False)

    Returns:
        Tuple[Dict[str, BaseDataset], List[List[str]], Dict[str, int]]:

        lifelong_datasets (Dict[str, BaseDataset]): a dictionary with the keys corresponding to the four splits (train,
        intask_valid, posttask_valid, test), and the values containing the dataset object inheriting from BaseDataset
        for that split.

        tasks (List[List[str]]): a list of lists where each inner list contains the set of classes (class names) that
        will be introduced in that task (example: [[dog, cat, car], [tiger, truck, fish]]).

        class_names_to_idx (Dict[str, int]): a dictionary with the class name as key, and the class index as value
        (example: {"dog": 0, "cat": 1, ...}).
    """
    assert framework in [
        PYTORCH, TENSORFLOW
    ], f'The framework is set to neither "{PYTORCH}" nor "{TENSORFLOW}"'
    assert setup in [
        IIRC_SETUP, CIL_SETUP
    ], f'The setup is set to neither "{IIRC_SETUP}" nor "{CIL_SETUP}"'
    assert dataset_name in datasets_names, f'The dataset_name is not in {datasets_names}'
    print_msg(f"Creating {dataset_name}")

    datasets, dataset_configuration = \
        _get_dataset(dataset_name=dataset_name, dataset_root=dataset_root)
    tasks, class_names_to_idx = _get_tasks_configuration(
        dataset_name,
        tasks_configuration_id=tasks_configuration_id,
        joint=joint)

    sprcla_data_pct = dataset_configuration["superclass_data_pct"]
    subcla_data_pct = dataset_configuration["subclass_data_pct"]
    using_image_path = dataset_configuration["using_image_path"]
    sprcla_sampling_size_cap = dataset_configuration[
        "superclass_sampling_size_cap"]
    lifelong_datasets = {}

    if framework == PYTORCH:
        from iirc.lifelong_dataset.torch_dataset import Dataset
        LifeLongDataset = Dataset
    elif framework == TENSORFLOW:
        from iirc.lifelong_dataset.tensorflow_dataset import Dataset
        LifeLongDataset = Dataset
    else:
        raise NotImplementedError

    print_msg(f"Setup used: {setup}\nUsing {framework}")

    shared_arguments = dict(
        tasks=tasks,
        setup=setup,
        using_image_path=using_image_path,
        cache_images=cache_images,
        essential_transforms_fn=essential_transforms_fn,
        superclass_data_pct=sprcla_data_pct,
        subclass_data_pct=subcla_data_pct,
        superclass_sampling_size_cap=sprcla_sampling_size_cap)
    lifelong_datasets["train"] = LifeLongDataset(
        dataset=datasets["train"],
        test_mode=False,
        augmentation_transforms_fn=augmentation_transforms_fn,
        **shared_arguments)
    lifelong_datasets["intask_valid"] = LifeLongDataset(
        dataset=datasets["intask_valid"], test_mode=False, **shared_arguments)
    lifelong_datasets["posttask_valid"] = LifeLongDataset(
        dataset=datasets["posttask_valid"], test_mode=True, **shared_arguments)
    lifelong_datasets["test"] = LifeLongDataset(dataset=datasets["test"],
                                                test_mode=True,
                                                **shared_arguments)

    print_msg("Dataset created")

    return lifelong_datasets, tasks, class_names_to_idx
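A usage sketch based on the signature and return values documented above (the dataset name comes from the docstring example; the transforms are illustrative torchvision choices, not prescribed by the function):

import torchvision.transforms as transforms

essential_transforms_fn = transforms.ToTensor()
augmentation_transforms_fn = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
lifelong_datasets, tasks, class_names_to_idx = get_lifelong_datasets(
    "iirc_cifar100",
    dataset_root="./data",
    essential_transforms_fn=essential_transforms_fn,
    augmentation_transforms_fn=augmentation_transforms_fn,
)
print(len(tasks), "tasks,", len(class_names_to_idx), "classes in total")
lifelong_datasets["train"].choose_task(0)        # all four splits expose the same task interface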