import os
from copy import deepcopy

import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder

# Project-local modules these snippets rely on; sample_dynamics, EWCPPLoss,
# test, and _train are assumed to be defined elsewhere in the same package.
import selector
import taskset
import utils


def get_sample_dynamics_learn(config, data_config):
    model = config['model']
    save_path = config['save_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    curriculum = [curriculum[x:x + classes_per_task] for x in range(0, total_classes, classes_per_task)]

    count = 0
    sum_learn = 0.0
    last_learn = 0.0

    pre_task = []
    test_task = []
    for task_idx, task in enumerate(curriculum):
        if task_idx == 0:
            continue

        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)

        utils.load_checkpoint(base_model, os.path.join(save_path, 'task%02d' % (task_idx - 1), 'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % task_idx, 'LastModel.pth.tar'))

        pre_task.extend(test_task)
        test_task = task

        test_taskset = taskset.Taskset(data_path, pre_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        # The base model has seen task_idx tasks; the new model one more.
        base_cml_classes = task_idx * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task

        pre_sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader, base_cml_classes,
                                                    new_cml_classes, device)

        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        # Slice both heads over all classes seen through the current task.
        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task

        cur_sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader, base_cml_classes,
                                                    new_cml_classes, device)

        # Of the samples the base model had not yet mastered (old-task mistakes
        # plus the whole current task), the fraction the new model now gets
        # right. This assumes sample_dynamics returns transition counts
        # (kept, forgotten, learned, never-learned); see the sketch below.
        last_learn = (pre_sample_dynamics_value[2] + cur_sample_dynamics_value[0] + cur_sample_dynamics_value[2]) / (
                pre_sample_dynamics_value[2] + pre_sample_dynamics_value[3] + len(test_taskset))
        sum_learn += last_learn
        count += 1

    avg_learn = sum_learn / count
    return last_learn, avg_learn
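
# A minimal sketch of the sample_dynamics() helper the metrics above rely on.
# The real implementation is not shown in this listing; this version assumes
# it returns the four prediction-transition counts between the two checkpoints:
# (correct->correct, correct->wrong, wrong->correct, wrong->wrong).
def sample_dynamics_sketch(base_model, new_model, loader, base_cml_classes,
                           new_cml_classes, device):
    n_kept = n_forgot = n_learned = n_never = 0
    base_model.eval()
    new_model.eval()
    with torch.no_grad():
        for data in loader:
            images, labels = data[0].to(device), data[1].to(device)
            # Restrict each model's logits to the classes it has seen so far.
            base_pred = base_model(images)[:, :base_cml_classes].argmax(dim=1)
            new_pred = new_model(images)[:, :new_cml_classes].argmax(dim=1)
            base_ok = base_pred.eq(labels)
            new_ok = new_pred.eq(labels)
            n_kept += (base_ok & new_ok).sum().item()
            n_forgot += (base_ok & ~new_ok).sum().item()
            n_learned += (~base_ok & new_ok).sum().item()
            n_never += (~base_ok & ~new_ok).sum().item()
    return n_kept, n_forgot, n_learned, n_never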


def get_sample_dynamics_intransigence(config, data_config):
    model = config['model']
    save_path = config['save_path']
    ref_path = config['ref_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    curriculum = [curriculum[x:x + classes_per_task] for x in range(0, total_classes, classes_per_task)]

    count = 0
    sum_intransigence = 0.0
    last_intransigence = 0.0

    test_task = []
    for task_idx, task in enumerate(curriculum):
        if task_idx == 0:
            continue

        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)

        utils.load_checkpoint(base_model, os.path.join(ref_path, 'task%02d' % task_idx, 'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % task_idx, 'LastModel.pth.tar'))

        test_task = task
        test_taskset = taskset.Taskset(data_path, test_task, task_idx, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task

        sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader, base_cml_classes, new_cml_classes,
                                                device)

        # Intransigence: share of the samples the reference model (loaded from
        # ref_path) got right that the continually trained model gets wrong,
        # assuming the (kept, forgotten, ...) counts sketched above.
        last_intransigence = sample_dynamics_value[1] / (sample_dynamics_value[0] + sample_dynamics_value[1])
        sum_intransigence += last_intransigence
        count += 1

    avg_intransigence = sum_intransigence / count
    return last_intransigence, avg_intransigence
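
# Worked example for the ratio above: if the reference model gets 90 of the
# task's samples right and the continual model loses 9 of them, the counts are
# kept = 81 and forgotten = 9, so intransigence = 9 / (81 + 9) = 0.1.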


def get_sample_dynamics_forgetting_list(config, data_config):
    model = config['model']
    save_path = config['save_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    curriculum = [curriculum[x:x + classes_per_task] for x in range(0, total_classes, classes_per_task)]

    forgetting_list = []

    test_task = []
    for task_idx, task in enumerate(curriculum):
        if task_idx == len(curriculum) - 1:
            break

        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)

        utils.load_checkpoint(base_model, os.path.join(save_path, 'task%02d' % task_idx, 'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % (task_idx + 1), 'LastModel.pth.tar'))

        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 2) * classes_per_task

        sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader, base_cml_classes, new_cml_classes,
                                                device)

        # Forgetting rate (%): samples the task_idx model got right that the
        # task_idx + 1 model gets wrong, over all samples it got right.
        forgetting_list.append(100 * (sample_dynamics_value[1] / (sample_dynamics_value[0] + sample_dynamics_value[1])))

    return forgetting_list
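
# Hypothetical driver for the three metric helpers above; every key mirrors a
# lookup at the top of those functions, and the paths, sizes, and transforms
# here are placeholders rather than values from the original project.
config = {
    'model': 'resnet',
    'save_path': './checkpoints/cl_run',
    'ref_path': './checkpoints/joint_reference',
    'data_path': './data',
    'batch_size': 128,
    'classes_per_task': 10,
    'num_workers': 4,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
}
data_config = {
    'total_classes': 100,
    'transform': {'test': None},  # plug in the project's test transform here
    'curriculums': list(range(100)),
}
last_learn, avg_learn = get_sample_dynamics_learn(config, data_config)
last_int, avg_int = get_sample_dynamics_intransigence(config, data_config)
forgetting_per_task = get_sample_dynamics_forgetting_list(config, data_config)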
Example #4
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                               0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
            temporal_ensemble (dict):
                num_labels (int): number of labeled samples (passed to taskset.Taskset).
        logger (Logger): logger for the tensorboard.
    """
    model = config['model']
    device = config['device']
    data_path = config['data_path']
    save_path = config['save_path']

    num_classes = data_config['total_classes']
    dataset = data_config['dataset']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']

    num_labels = method_config['temporal_ensemble']['num_labels']

    train_taskset = taskset.Taskset(data_path,
                                    train=True,
                                    transform=train_transform,
                                    num_labels=num_labels,
                                    num_classes=num_classes)
    test_taskset = taskset.Taskset(data_path,
                                   train=False,
                                   transform=test_transform)
    '''Make network'''
    net = selector.model(model, device, num_classes)

    _train(net, train_taskset, test_taskset, config, method_config,
           data_config, logger)
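
# The method name above refers to temporal ensembling (Laine & Aila, 2017);
# _train's internals are not shown in this listing, so the update below is a
# minimal sketch of the target computation such a trainer typically performs:
# an EMA of per-sample predictions with a startup bias correction.
def temporal_ensemble_target(Z, z_epoch, alpha, epoch):
    Z = alpha * Z + (1 - alpha) * z_epoch      # running average of predictions
    target = Z / (1 - alpha ** (epoch + 1))    # correct the zero-init bias
    return Z, target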
Example #5
def consolidation(task_idx, old_net, cur_net, taskset, test_taskset, config,
                  method_config, data_config, logger):

    model = config['model']
    cml_classes = config['cml_classes']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    task_path = config['task_path']
    num_workers = config['num_workers']
    device = config['device']

    process_list = method_config['consolidation_process_list']

    log = utils.Log(task_path)
    epoch = 0
    single_best_accuracy, multi_best_accuracy = 0.0, 0.0
    net = selector.model(model, device, cml_classes)

    for param in old_net.parameters():
        param.requires_grad = False

    for param in cur_net.parameters():
        param.requires_grad = False

    for process in process_list:
        epochs = process['epochs']
        optimizer = process['optimizer'](net.parameters())
        scheduler = process['scheduler'](optimizer)
        criterion = nn.MSELoss()

        train_loader = torch.utils.data.DataLoader(taskset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers)
        test_loader = torch.utils.data.DataLoader(test_taskset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=num_workers)

        log.info("Start Consolidation")
        for ep in range(epochs):
            log.info("%d Epoch Started" % epoch)
            net.train()
            old_net.eval()
            cur_net.eval()
            epoch_loss = 0.0
            total = 0

            for i, data in enumerate(train_loader):
                utils.printProgressBar(i + 1,
                                       len(train_loader),
                                       prefix='train')
                images = data[0].to(device)
                cur_batch_size = images.size(0)

                optimizer.zero_grad()

                # Zero-center each teacher's logits per sample, then
                # concatenate them along the class dimension to build the
                # regression target for the consolidated student.
                outputs_old = old_net(images)
                outputs_cur = cur_net(images)
                outputs_old = outputs_old - outputs_old.mean(dim=1, keepdim=True)
                outputs_cur = outputs_cur - outputs_cur.mean(dim=1, keepdim=True)

                outputs_tot = torch.cat((outputs_old, outputs_cur), dim=1)
                outputs = net(images)
                loss = criterion(outputs, outputs_tot)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item() * cur_batch_size
                total += cur_batch_size

            epoch_loss /= total
            selector.scheduler.step(scheduler, epoch_loss)

            log.info("epoch: %d  train_loss: %.3lf  train_sample: %d" %
                     (epoch, epoch_loss, total))

            if ep == (epochs - 1):
                test_loss, single_total_accuracy, single_class_accuracy, multi_total_accuracy, multi_class_accuracy = \
                    test(net, test_loader, config, data_config, task_idx, False)
                torch.save(net, os.path.join(task_path, 'consolidated_model'))
            logger.epoch_step()
            epoch += 1

        log.info("Finish Consolidation")

    return net
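
# A tiny, self-contained illustration of the consolidation target built above,
# with hypothetical sizes: each teacher's logits are zero-centered per sample
# and concatenated along the class dimension, so the student regresses a
# (batch, old_classes + new_classes) target with MSE.
old_logits = torch.randn(4, 10)  # stand-in for old_net(images)
cur_logits = torch.randn(4, 5)   # stand-in for cur_net(images)
old_logits = old_logits - old_logits.mean(dim=1, keepdim=True)
cur_logits = cur_logits - cur_logits.mean(dim=1, keepdim=True)
target = torch.cat((old_logits, cur_logits), dim=1)  # shape: (4, 15)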
Example #6
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                               0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model = config['model']
    classes_per_task = config['classes_per_task']
    memory_cap = config['memory_cap']
    device = config['device']
    data_path = config['data_path']
    external_data_path = method_config['external_data_path']
    external_cifar_data_path = method_config['external_cifar_data_path']
    save_path = config['save_path']

    total_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    dataset = data_config['dataset']
    '''Split curriculum [[task0], [task1], [task2], ...]'''
    curriculum = [
        curriculum[x:x + classes_per_task]
        for x in range(0, total_classes, classes_per_task)
    ]
    '''Make sample memory'''
    sample_memory = taskset.SampleMemory(data_path,
                                         total_classes,
                                         len(curriculum),
                                         curriculum,
                                         transform=train_transform,
                                         capacity=memory_cap)

    external_taskset = []
    if dataset == 'cifar100':
        external_data_transform = method_config['external_data_transform']
        external_taskset = ImageFolder(os.path.join(external_data_path,
                                                    'train'),
                                       transform=external_data_transform)
    elif dataset == 'cifar10':
        external_data_transform = method_config[
            'external_cifar_data_transform']
        external_taskset = ImageFolder(os.path.join(external_cifar_data_path,
                                                    'train'),
                                       transform=external_data_transform)

    test_task = []
    train_task = []
    old_net = None
    cur_net = None
    '''Taskwise iteration'''
    for task_idx, task in enumerate(curriculum):
        train_task = task
        train_taskset = taskset.Taskset(data_path,
                                        train_task,
                                        task_idx,
                                        train=True,
                                        transform=train_transform)
        val_taskset = taskset.Taskset(data_path,
                                      train_task,
                                      task_idx,
                                      train=False,
                                      transform=test_transform)
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path,
                                       test_task,
                                       0,
                                       train=False,
                                       transform=test_transform)
        '''Make directory of current task'''
        if not os.path.exists(os.path.join(save_path, 'task%02d' % task_idx)):
            os.makedirs(os.path.join(save_path, 'task%02d' % task_idx))

        config['task_path'] = os.path.join(save_path, 'task%02d' % task_idx)
        config['cml_classes'] = len(test_task)
        '''Make network'''
        net = selector.model(model, device, classes_per_task)
        if task_idx == 0:
            old_net = _train(task_idx, net, train_taskset, sample_memory,
                             val_taskset, config, method_config, data_config,
                             logger)
        else:
            cur_net = _train(task_idx, net, train_taskset, sample_memory,
                             val_taskset, config, method_config, data_config,
                             logger)

        if cur_net is not None:
            old_net = consolidation(task_idx, old_net, cur_net,
                                    external_taskset, test_taskset, config,
                                    method_config, data_config, logger)

        sample_memory.update()
        torch.save(sample_memory,
                   os.path.join(config['task_path'], 'sample_memory'))
Example #7
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                               0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model_ = config['model']
    classes_per_task = config['classes_per_task']
    memory_cap = config['memory_cap']
    device = config['device']
    data_path = config['data_path']
    save_path = config['save_path']

    total_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']

    '''Split curriculum [[task0], [task1], [task2], ...]'''
    curriculum = [curriculum[x:x + classes_per_task] for x in range(0, total_classes, classes_per_task)]

    '''Make sample memory'''
    sample_memory = taskset.SampleMemory(data_path, total_classes, len(curriculum), curriculum,
                                         transform=train_transform, capacity=memory_cap)

    model = selector.model(model_, device, total_classes, classes_per_task)
    fisher = None
    importance = 75000  # weight of the EWC penalty term
    alpha = 0.9         # decay for the running Fisher estimate
    normalize = True
    test_task = []
    train_task = []
    '''Taskwise iteration'''
    for task_idx, task in enumerate(curriculum):
        train_task = task
        train_taskset = taskset.Taskset(data_path, train_task, task_idx, train=True, transform=train_transform)
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)

        '''Make directory of current task'''
        if not os.path.exists(os.path.join(save_path, 'task%02d' % task_idx)):
            os.makedirs(os.path.join(save_path, 'task%02d' % task_idx))

        config['task_path'] = os.path.join(save_path, 'task%02d' % task_idx)
        config['cml_classes'] = len(test_task)
        

        # Keep a frozen copy of the current weights as the EWC anchor.
        model_old = deepcopy(model)
        for p in model_old.parameters():
            p.requires_grad = False
        ewc = EWCPPLoss(model, model_old, fisher=fisher, alpha=alpha, normalize=normalize)
        model, ewc = _train(task_idx, model, ewc, importance, train_taskset,
                            sample_memory, test_taskset, config, method_config,
                            data_config, logger)

        fisher = deepcopy(ewc.get_fisher())
        #sample_memory.update()
        torch.save(sample_memory, os.path.join(config['task_path'], 'sample_memory'))
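
# A minimal sketch of the EWCPPLoss helper assumed above ("EWC++", i.e. online
# EWC). The project's real class is not shown in this listing; this version
# keeps a running Fisher estimate F <- alpha * F + (1 - alpha) * grad^2 and
# exposes the quadratic penalty that _train would scale by `importance`.
class EWCPPLossSketch:
    def __init__(self, model, model_old, fisher=None, alpha=0.9, normalize=True):
        self.model = model
        self.model_old = model_old
        self.alpha = alpha
        self.normalize = normalize
        if fisher is None:
            fisher = {n: torch.zeros_like(p)
                      for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = fisher

    def update_fisher(self):
        # Call after loss.backward(): fold squared gradients into the estimate.
        for n, p in self.model.named_parameters():
            if p.grad is not None and n in self.fisher:
                self.fisher[n] = (self.alpha * self.fisher[n]
                                  + (1 - self.alpha) * p.grad.detach() ** 2)

    def get_fisher(self):
        return self.fisher

    def penalty(self):
        old_params = {n: p.detach() for n, p in self.model_old.named_parameters()}
        total = 0.0
        for n, p in self.model.named_parameters():
            f = self.fisher.get(n)
            if f is None:
                continue
            if self.normalize:
                f = f / (f.max() + 1e-8)  # rescale importances into [0, 1]
            total = total + (f * (p - old_params[n]) ** 2).sum()
        return total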