コード例 #1
0
def get_sample_dynamics_learn(config, data_config):
    """Compute the sample-dynamics "learn" metric over a class-incremental run.

    For each task t >= 1, the checkpoints saved after task t-1 (base) and
    task t (new) are compared twice with ``sample_dynamics``: once on every
    class seen strictly before task t, and once on task t's own classes.
    The per-task learn ratio combines both results.

    Args:
        config (dict): needs 'model', 'save_path', 'data_path', 'batch_size',
            'classes_per_task', 'num_workers', 'device'.
        data_config (dict): needs 'total_classes', 'transform' (with a 'test'
            entry), 'curriculums'.

    Returns:
        tuple: (last_learn, avg_learn) — the ratio for the final task and the
        mean over all evaluated tasks; (0.0, 0.0) when the curriculum has
        fewer than two tasks (nothing to compare).
    """
    model = config['model']
    save_path = config['save_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    curriculum = [curriculum[x:x + classes_per_task] for x in range(0, total_classes, classes_per_task)]

    cou = 0
    sum_learn = 0.0
    last_learn = 0.0

    pre_task = []
    test_task = []
    for task_idx, task in enumerate(curriculum):
        if task_idx == 0:
            # No previous checkpoint exists for the first task; remember its
            # classes so they enter pre_task from task 1 onward.  (Bug fix:
            # previously test_task stayed empty here, so task 0's classes were
            # silently dropped from every pre_task evaluation.)
            test_task = task
            continue

        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)

        utils.load_checkpoint(base_model, os.path.join(save_path, 'task%02d' % (task_idx - 1), 'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % task_idx, 'LastModel.pth.tar'))

        # pre_task = classes seen strictly before this task; test_task = this task.
        pre_task.extend(test_task)
        test_task = task

        test_taskset = taskset.Taskset(data_path, pre_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        # The base model has seen task_idx tasks; the new model one more.
        base_cml_classes = task_idx * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task

        pre_sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader, base_cml_classes,
                                                    new_cml_classes, device)

        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        # Current-task classes: both models are evaluated over the same
        # (task_idx + 1) * classes_per_task output range.
        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task

        cur_sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader, base_cml_classes,
                                                    new_cml_classes, device)

        # NOTE(review): the indices into sample_dynamics' return tuple are
        # assumed to be transition counts (e.g. kept/forgotten/learned/...) —
        # confirm against the sample_dynamics definition.
        last_learn = (pre_sample_dynamics_value[2] + cur_sample_dynamics_value[0] + cur_sample_dynamics_value[2]) / (
                pre_sample_dynamics_value[2] + pre_sample_dynamics_value[3] + len(test_taskset))
        sum_learn += last_learn
        cou += 1

    if cou == 0:
        # Single-task curriculum: avoid the previous ZeroDivisionError.
        return 0.0, 0.0

    avg_learn = sum_learn / cou
    return last_learn, avg_learn
コード例 #2
0
def train(config, method_config, data_config, logger):
    """Build the full (non-incremental) train/test tasksets and launch training.

    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                               0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    # Only the keys actually consumed here are extracted (removed the unused
    # locals save_path / dataset / test_task / train_task).
    model = config['model']
    device = config['device']
    data_path = config['data_path']

    num_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']

    # Semi-supervised setting: number of labelled samples for the
    # temporal-ensemble method.
    num_labels = method_config['temporal_ensemble']['num_labels']

    train_taskset = taskset.Taskset(data_path,
                                    train=True,
                                    transform=train_transform,
                                    num_labels=num_labels,
                                    num_classes=num_classes)
    test_taskset = taskset.Taskset(data_path,
                                   train=False,
                                   transform=test_transform)
    '''Make network'''
    net = selector.model(model, device, num_classes)

    _train(net, train_taskset, test_taskset, config, method_config,
           data_config, logger)
コード例 #3
0
def get_sample_dynamics_intransigence(config, data_config):
    """Measure per-task intransigence from sample dynamics.

    For every task after the first, the reference checkpoint (under
    'ref_path') and the method's checkpoint (under 'save_path') of the same
    task index are compared on that task's test classes; the fraction
    ``value[1] / (value[0] + value[1])`` of the sample-dynamics counts is
    recorded per task.

    Returns:
        tuple: (intransigence of the final task, average over all evaluated
        tasks).
    """
    arch = config['model']
    save_path = config['save_path']
    ref_path = config['ref_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    tasks = data_config['curriculums']
    tasks = [tasks[start:start + classes_per_task]
             for start in range(0, total_classes, classes_per_task)]

    scores = []
    for idx, classes in enumerate(tasks):
        # The first task has no reference comparison.
        if idx == 0:
            continue

        ref_model = selector.model(arch, device, total_classes)
        cur_model = selector.model(arch, device, total_classes)

        utils.load_checkpoint(ref_model,
                              os.path.join(ref_path, 'task%02d' % idx, 'LastModel.pth.tar'))
        utils.load_checkpoint(cur_model,
                              os.path.join(save_path, 'task%02d' % idx, 'LastModel.pth.tar'))

        eval_set = taskset.Taskset(data_path, classes, idx, train=False, transform=test_transform)
        eval_loader = torch.utils.data.DataLoader(eval_set, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        # Both checkpoints cover the same cumulative class range.
        cml_classes = (idx + 1) * classes_per_task
        dyn = sample_dynamics(ref_model, cur_model, eval_loader,
                              cml_classes, cml_classes, device)

        scores.append(dyn[1] / (dyn[0] + dyn[1]))

    avg_intransigence = sum(scores) / len(scores)
    return scores[-1], avg_intransigence
コード例 #4
0
def get_sample_dynamics_forgetting_list(config, data_config):
    """Collect the per-transition forgetting percentage between checkpoints.

    For each task t except the last, the checkpoints saved after task t and
    task t+1 are compared on every class seen up to and including task t;
    ``100 * value[1] / (value[0] + value[1])`` of the sample-dynamics counts
    is appended per transition.

    Returns:
        list: forgetting percentage for each consecutive task pair.
    """
    arch = config['model']
    save_path = config['save_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    schedule = data_config['curriculums']
    schedule = [schedule[start:start + classes_per_task]
                for start in range(0, total_classes, classes_per_task)]

    results = []
    seen_classes = []
    # The last task has no successor checkpoint, so it is excluded.
    for idx in range(len(schedule) - 1):
        older = selector.model(arch, device, total_classes)
        newer = selector.model(arch, device, total_classes)

        utils.load_checkpoint(older,
                              os.path.join(save_path, 'task%02d' % idx, 'LastModel.pth.tar'))
        utils.load_checkpoint(newer,
                              os.path.join(save_path, 'task%02d' % (idx + 1), 'LastModel.pth.tar'))

        # Accumulate classes seen so far and evaluate on all of them.
        seen_classes.extend(schedule[idx])
        eval_set = taskset.Taskset(data_path, seen_classes, 0, train=False, transform=test_transform)
        eval_loader = torch.utils.data.DataLoader(eval_set, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        dyn = sample_dynamics(older, newer, eval_loader,
                              (idx + 1) * classes_per_task,
                              (idx + 2) * classes_per_task,
                              device)

        results.append(100 * (dyn[1] / (dyn[0] + dyn[1])))

    return results
コード例 #5
0
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                               0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model = config['model']
    classes_per_task = config['classes_per_task']
    memory_cap = config['memory_cap']
    device = config['device']
    data_path = config['data_path']
    external_data_path = method_config['external_data_path']
    external_cifar_data_path = method_config['external_cifar_data_path']
    save_path = config['save_path']

    total_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    dataset = data_config['dataset']
    '''Split curriculum [[task0], [task1], [task2], ...]'''
    curriculum = [
        curriculum[x:x + classes_per_task]
        for x in range(0, total_classes, classes_per_task)
    ]
    '''Make sample memory'''
    sample_memory = taskset.SampleMemory(data_path,
                                         total_classes,
                                         len(curriculum),
                                         curriculum,
                                         transform=train_transform,
                                         capacity=memory_cap)

    # Auxiliary (external) dataset fed to the consolidation step; the path and
    # transform are dataset-specific and come from method_config.  Left as an
    # empty list for datasets other than cifar10/cifar100.
    external_taskset = []
    if dataset == 'cifar100':
        external_data_transform = method_config['external_data_transform']
        external_taskset = ImageFolder(os.path.join(external_data_path,
                                                    'train'),
                                       transform=external_data_transform)
    if dataset == 'cifar10':
        external_data_transform = method_config[
            'external_cifar_data_transform']
        external_taskset = ImageFolder(os.path.join(external_cifar_data_path,
                                                    'train'),
                                       transform=external_data_transform)

    test_task = []
    train_task = []
    # old_net: consolidated model accumulated so far.
    # cur_net: model freshly trained on the current task only.
    old_net = None
    cur_net = None
    '''Taskwise iteration'''
    for task_idx, task in enumerate(curriculum):
        train_task = task
        train_taskset = taskset.Taskset(data_path,
                                        train_task,
                                        task_idx,
                                        train=True,
                                        transform=train_transform)
        val_taskset = taskset.Taskset(data_path,
                                      train_task,
                                      task_idx,
                                      train=False,
                                      transform=test_transform)
        # test_task grows cumulatively: evaluation covers all classes so far.
        # NOTE(review): the third Taskset argument is 0 here (vs task_idx
        # above) — presumably it marks the cumulative split; confirm against
        # taskset.Taskset.
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path,
                                       test_task,
                                       0,
                                       train=False,
                                       transform=test_transform)
        '''Make directory of current task'''
        if not os.path.exists(os.path.join(save_path, 'task%02d' % task_idx)):
            os.makedirs(os.path.join(save_path, 'task%02d' % task_idx))

        # Per-task output directory and the cumulative class count consumed
        # by downstream helpers.
        config['task_path'] = os.path.join(save_path, 'task%02d' % task_idx)
        config['cml_classes'] = len(test_task)
        '''Make network'''
        # A fresh head sized for one task; consolidation merges it later.
        net = selector.model(model, device, classes_per_task)
        if task_idx == 0:
            # First task: its trained model seeds old_net; no consolidation.
            old_net = _train(task_idx, net, train_taskset, sample_memory,
                             val_taskset, config, method_config, data_config,
                             logger)
        else:
            cur_net = _train(task_idx, net, train_taskset, sample_memory,
                             val_taskset, config, method_config, data_config,
                             logger)

        # From task 1 onward, merge cur_net into old_net using the external
        # data and the cumulative test set.
        if cur_net is not None:
            old_net = consolidation(task_idx, old_net, cur_net,
                                    external_taskset, test_taskset, config,
                                    method_config, data_config, logger)

        # Refresh the replay memory with the current task, then persist it
        # alongside the task's checkpoints.
        sample_memory.update()
        torch.save(sample_memory,
                   os.path.join(config['task_path'], 'sample_memory'))
コード例 #6
0
def train(config, method_config, data_config, logger):
    """Taskwise EWC++ training loop over a class-incremental curriculum.

    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                               0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model_ = config['model']
    classes_per_task = config['classes_per_task']
    memory_cap = config['memory_cap']
    device = config['device']
    data_path = config['data_path']
    save_path = config['save_path']

    total_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']

    '''Split curriculum [[task0], [task1], [task2], ...]'''
    curriculum = [curriculum[x:x + classes_per_task] for x in range(0, total_classes, classes_per_task)]

    '''Make sample memory'''
    sample_memory = taskset.SampleMemory(data_path, total_classes, len(curriculum), curriculum,
                                         transform=train_transform, capacity=memory_cap)

    # One model is reused across all tasks (no per-task network).
    model = selector.model(model_, device, total_classes, classes_per_task)

    # EWC++ hyper-parameters.  Hard-coded for now — TODO(review): move them
    # into method_config so runs are configurable without code edits.
    fisher = None       # running Fisher information, carried across tasks
    importance = 75000  # weight of the EWC penalty term
    alpha = 0.9         # decay factor for the online Fisher update
    normalize = True

    test_task = []
    '''Taskwise iteration'''
    for task_idx, task in enumerate(curriculum):
        train_task = task
        train_taskset = taskset.Taskset(data_path, train_task, task_idx, train=True, transform=train_transform)
        # test_task grows cumulatively: evaluation covers all classes so far.
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)

        '''Make directory of current task'''
        os.makedirs(os.path.join(save_path, 'task%02d' % task_idx), exist_ok=True)

        config['task_path'] = os.path.join(save_path, 'task%02d' % task_idx)
        config['cml_classes'] = len(test_task)

        # Freeze a snapshot of the model before this task; the EWC penalty is
        # computed against this frozen copy.
        model_old = deepcopy(model)
        for p in model_old.parameters():
            p.requires_grad = False
        ewc = EWCPPLoss(model, model_old, fisher=fisher, alpha=alpha, normalize=normalize)
        model, ewc = _train(task_idx, model, ewc, importance, train_taskset, sample_memory, test_taskset, config, method_config, data_config, logger)

        # Carry the updated Fisher information into the next task.
        fisher = deepcopy(ewc.get_fisher())
        # NOTE(review): sample_memory.update() is intentionally not called
        # here (it was commented out), yet the memory is still saved — confirm
        # whether the replay memory should be refreshed per task.
        torch.save(sample_memory, os.path.join(config['task_path'], 'sample_memory'))