# NOTE: the functions below are excerpts from several training/evaluation modules of
# this project. The standard-library and torch imports listed here are the ones they
# rely on; selector, utils, taskset, sample_dynamics, EWCPPLoss, test and _train are
# assumed to be provided elsewhere in the package.
import os
from copy import deepcopy

import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder


def get_sample_dynamics_learn(config, data_config):
    """Compute the sample-dynamics 'learn' metric for the last task and its task-wise average."""
    model = config['model']
    save_path = config['save_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']

    # Split the flat class curriculum into tasks of classes_per_task classes each.
    curriculum = [curriculum[x:x + classes_per_task]
                  for x in range(0, total_classes, classes_per_task)]

    cou = 0
    sum_learn = 0.0
    last_learn = 0.0
    pre_task = []
    test_task = []
    for task_idx, task in enumerate(curriculum):
        # The metric compares consecutive checkpoints, so the first task is skipped.
        if task_idx == 0:
            continue

        # Load the checkpoints saved after the previous task and after the current task.
        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)
        utils.load_checkpoint(base_model, os.path.join(save_path, 'task%02d' % (task_idx - 1),
                                                       'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % task_idx,
                                                      'LastModel.pth.tar'))

        pre_task.extend(test_task)
        test_task = task

        # Sample dynamics on the classes seen before the current task.
        test_taskset = taskset.Taskset(data_path, pre_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)
        base_cml_classes = task_idx * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task
        pre_sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader,
                                                    base_cml_classes, new_cml_classes, device)

        # Sample dynamics on the classes of the current task.
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)
        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task
        cur_sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader,
                                                    base_cml_classes, new_cml_classes, device)

        last_learn = (pre_sample_dynamics_value[2] + cur_sample_dynamics_value[0]
                      + cur_sample_dynamics_value[2]) / (pre_sample_dynamics_value[2]
                                                         + pre_sample_dynamics_value[3]
                                                         + len(test_taskset))
        sum_learn += last_learn
        cou += 1

    avg_learn = sum_learn / cou
    return last_learn, avg_learn
def get_sample_dynamics_intransigence(config, data_config):
    """Compute the sample-dynamics 'intransigence' metric for the last task and its task-wise average."""
    model = config['model']
    save_path = config['save_path']
    ref_path = config['ref_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']

    # Split the flat class curriculum into tasks of classes_per_task classes each.
    curriculum = [curriculum[x:x + classes_per_task]
                  for x in range(0, total_classes, classes_per_task)]

    cou = 0
    sum_intransigence = 0.0
    last_intransigence = 0.0
    test_task = []
    for task_idx, task in enumerate(curriculum):
        if task_idx == 0:
            continue

        # Compare the reference model (ref_path) with the evaluated model (save_path)
        # at the same task index.
        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)
        utils.load_checkpoint(base_model, os.path.join(ref_path, 'task%02d' % task_idx,
                                                       'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % task_idx,
                                                      'LastModel.pth.tar'))

        test_task = task
        test_taskset = taskset.Taskset(data_path, test_task, task_idx, train=False,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 1) * classes_per_task
        sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader,
                                                base_cml_classes, new_cml_classes, device)

        last_intransigence = sample_dynamics_value[1] / (sample_dynamics_value[0]
                                                         + sample_dynamics_value[1])
        sum_intransigence += last_intransigence
        cou += 1

    avg_intransigence = sum_intransigence / cou
    return last_intransigence, avg_intransigence
def get_sample_dynamics_forgetting_list(config, data_config):
    """Compute the per-task sample-dynamics forgetting values (in percent)."""
    model = config['model']
    save_path = config['save_path']
    data_path = config['data_path']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    num_workers = config['num_workers']
    device = config['device']

    total_classes = data_config['total_classes']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']

    # Split the flat class curriculum into tasks of classes_per_task classes each.
    curriculum = [curriculum[x:x + classes_per_task]
                  for x in range(0, total_classes, classes_per_task)]

    forgetting_list = []
    test_task = []
    for task_idx, task in enumerate(curriculum):
        # The last task has no successor checkpoint to compare against.
        if task_idx == len(curriculum) - 1:
            break

        # Compare the checkpoint after the current task with the one after the next task.
        base_model = selector.model(model, device, total_classes)
        new_model = selector.model(model, device, total_classes)
        utils.load_checkpoint(base_model, os.path.join(save_path, 'task%02d' % task_idx,
                                                       'LastModel.pth.tar'))
        utils.load_checkpoint(new_model, os.path.join(save_path, 'task%02d' % (task_idx + 1),
                                                      'LastModel.pth.tar'))

        # Evaluate on all classes seen up to and including the current task.
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False, transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        base_cml_classes = (task_idx + 1) * classes_per_task
        new_cml_classes = (task_idx + 2) * classes_per_task
        sample_dynamics_value = sample_dynamics(base_model, new_model, test_loader,
                                                base_cml_classes, new_cml_classes, device)

        forgetting_list.append(100 * (sample_dynamics_value[1]
                                      / (sample_dynamics_value[0] + sample_dynamics_value[1])))

    return forgetting_list
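
# --- Usage sketch (illustrative, not part of the original modules) ---
# A minimal, hypothetical driver for the three metric helpers above. The config keys
# are the ones those functions actually read; the values shown (model name, paths,
# batch size, device) are placeholder assumptions and would normally come from the
# experiment and data configuration files.
def _example_sample_dynamics_report(data_config):
    config = {
        'model': 'resnet32',            # assumed model name understood by selector.model
        'save_path': './results/run0',  # directory holding task%02d/LastModel.pth.tar
        'ref_path': './results/ref0',   # reference run used by the intransigence metric
        'data_path': './data',
        'batch_size': 128,
        'classes_per_task': 10,
        'num_workers': 4,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    }
    last_learn, avg_learn = get_sample_dynamics_learn(config, data_config)
    last_int, avg_int = get_sample_dynamics_intransigence(config, data_config)
    forgetting = get_sample_dynamics_forgetting_list(config, data_config)
    print('learn:         last %.4f / avg %.4f' % (last_learn, avg_learn))
    print('intransigence: last %.4f / avg %.4f' % (last_int, avg_int))
    print('forgetting per task (%): ' + ', '.join('%.2f' % f for f in forgetting))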
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model = config['model']
    device = config['device']
    data_path = config['data_path']
    save_path = config['save_path']

    num_classes = data_config['total_classes']
    dataset = data_config['dataset']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']

    num_labels = method_config['temporal_ensemble']['num_labels']

    test_task = []
    train_task = []

    train_taskset = taskset.Taskset(data_path, train=True, transform=train_transform,
                                    num_labels=num_labels, num_classes=num_classes)
    test_taskset = taskset.Taskset(data_path, train=False, transform=test_transform)

    '''Make network'''
    net = selector.model(model, device, num_classes)

    _train(net, train_taskset, test_taskset, config, method_config, data_config, logger)
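
# Illustrative sketch of the method_config block this train() expects. The
# 'temporal_ensemble' entry with 'num_labels' is the only method-specific key read
# here; the values shown are placeholder assumptions, not project defaults.
example_method_config = {
    'method': 'temporal_ensemble',
    'package': 'temporal_ensemble',
    'process_list': [],          # filled from the method's config file in practice
    'temporal_ensemble': {
        'num_labels': 4000,      # number of labeled samples; placeholder value
    },
}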
def consolidation(task_idx, old_net, cur_net, taskset, test_taskset, config, method_config,
                  data_config, logger):
    """Distill the frozen old and current networks into a single consolidated network."""
    model = config['model']
    cml_classes = config['cml_classes']
    batch_size = config['batch_size']
    classes_per_task = config['classes_per_task']
    task_path = config['task_path']
    num_workers = config['num_workers']
    device = config['device']
    print(model, cml_classes, batch_size, classes_per_task)

    process_list = method_config['consolidation_process_list']

    log = utils.Log(task_path)
    epoch = 0
    single_best_accuracy, multi_best_accuracy = 0.0, 0.0

    # The consolidated network covers all classes seen so far; both teachers are frozen.
    net = selector.model(model, device, cml_classes)
    for param in old_net.parameters():
        param.requires_grad = False
    for param in cur_net.parameters():
        param.requires_grad = False

    for process in process_list:
        epochs = process['epochs']
        optimizer = process['optimizer'](net.parameters())
        scheduler = process['scheduler'](optimizer)
        criterion = nn.MSELoss()

        train_loader = torch.utils.data.DataLoader(taskset, batch_size=batch_size, shuffle=True,
                                                   num_workers=num_workers)
        test_loader = torch.utils.data.DataLoader(test_taskset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)

        log.info("Start Consolidation")
        for ep in range(epochs):
            log.info("%d Epoch Started" % epoch)
            net.train()
            old_net.eval()
            cur_net.eval()

            epoch_loss = 0.0
            total = 0
            for i, data in enumerate(train_loader):
                utils.printProgressBar(i + 1, len(train_loader), prefix='train')
                images = data[0].to(device)
                cur_batch_size = images.size(0)

                optimizer.zero_grad()

                # Mean-center each teacher's logits, then concatenate them to form the
                # regression target for the consolidated network.
                outputs_old = old_net(images)
                outputs_cur = cur_net(images)
                outputs_old -= outputs_old.mean(dim=1).reshape(cur_batch_size, -1)
                outputs_cur -= outputs_cur.mean(dim=1).reshape(cur_batch_size, -1)
                outputs_tot = torch.cat((outputs_old, outputs_cur), dim=1)

                outputs = net(images)
                loss = criterion(outputs, outputs_tot)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item() * cur_batch_size
                total += cur_batch_size

            epoch_loss /= total
            selector.scheduler.step(scheduler, epoch_loss)
            log.info("epoch: %d train_loss: %.3lf train_sample: %d" % (epoch, epoch_loss, total))

            if ep == (epochs - 1):
                test_loss, single_total_accuracy, single_class_accuracy, \
                    multi_total_accuracy, multi_class_accuracy = \
                    test(net, test_loader, config, data_config, task_idx, False)
                torch.save(net, os.path.join(task_path, 'consolidated_model'))

            logger.epoch_step()
            epoch += 1

        log.info("Finish Consolidation")

    return net
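
# --- Target-construction sketch (illustrative, standalone) ---
# A self-contained reproduction of the regression target used in consolidation()
# above: both teachers' logits are mean-centered per sample and concatenated, and
# the consolidated network is trained with an MSE loss against that target. The
# tensor shapes and values below are arbitrary placeholders.
def _consolidation_target_example():
    torch.manual_seed(0)
    batch, old_classes, new_classes = 4, 10, 5
    outputs_old = torch.randn(batch, old_classes)   # stands in for old_net(images)
    outputs_cur = torch.randn(batch, new_classes)   # stands in for cur_net(images)

    # Per-sample mean-centering; keepdim=True is equivalent to the reshape used above.
    outputs_old = outputs_old - outputs_old.mean(dim=1, keepdim=True)
    outputs_cur = outputs_cur - outputs_cur.mean(dim=1, keepdim=True)
    target = torch.cat((outputs_old, outputs_cur), dim=1)   # shape: (batch, old + new)

    # The consolidated network's logits would replace this random stand-in.
    outputs = torch.randn(batch, old_classes + new_classes, requires_grad=True)
    loss = nn.MSELoss()(outputs, target)
    return target.shape, loss.item()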
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model = config['model']
    classes_per_task = config['classes_per_task']
    memory_cap = config['memory_cap']
    device = config['device']
    data_path = config['data_path']
    external_data_path = method_config['external_data_path']
    external_cifar_data_path = method_config['external_cifar_data_path']
    save_path = config['save_path']

    total_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']
    dataset = data_config['dataset']

    '''Split curriculum [[task0], [task1], [task2], ...]'''
    curriculum = [curriculum[x:x + classes_per_task]
                  for x in range(0, total_classes, classes_per_task)]

    '''Make sample memory'''
    sample_memory = taskset.SampleMemory(data_path, total_classes, len(curriculum), curriculum,
                                         transform=train_transform, capacity=memory_cap)

    '''Make external dataset used during consolidation'''
    external_taskset = []
    if dataset == 'cifar100':
        external_data_transform = method_config['external_data_transform']
        external_taskset = ImageFolder(os.path.join(external_data_path, 'train'),
                                       transform=external_data_transform)
    if dataset == 'cifar10':
        external_data_transform = method_config['external_cifar_data_transform']
        external_taskset = ImageFolder(os.path.join(external_cifar_data_path, 'train'),
                                       transform=external_data_transform)

    test_task = []
    train_task = []
    old_net = None
    cur_net = None

    '''Taskwise iteration'''
    for task_idx, task in enumerate(curriculum):
        train_task = task
        train_taskset = taskset.Taskset(data_path, train_task, task_idx, train=True,
                                        transform=train_transform)
        val_taskset = taskset.Taskset(data_path, train_task, task_idx, train=False,
                                      transform=test_transform)
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False,
                                       transform=test_transform)

        '''Make directory of current task'''
        if not os.path.exists(os.path.join(save_path, 'task%02d' % task_idx)):
            os.makedirs(os.path.join(save_path, 'task%02d' % task_idx))
        config['task_path'] = os.path.join(save_path, 'task%02d' % task_idx)
        config['cml_classes'] = len(test_task)

        '''Make network'''
        net = selector.model(model, device, classes_per_task)

        if task_idx == 0:
            old_net = _train(task_idx, net, train_taskset, sample_memory, val_taskset,
                             config, method_config, data_config, logger)
        else:
            cur_net = _train(task_idx, net, train_taskset, sample_memory, val_taskset,
                             config, method_config, data_config, logger)

        if cur_net is not None:
            old_net = consolidation(task_idx, old_net, cur_net, external_taskset, test_taskset,
                                    config, method_config, data_config, logger)

        sample_memory.update()
        torch.save(sample_memory, os.path.join(config['task_path'], 'sample_memory'))
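
# --- Curriculum-split sketch (illustrative) ---
# The list comprehension used in the train() functions above turns the flat class
# curriculum into consecutive tasks of classes_per_task classes each. The toy values
# below are placeholders chosen only to show the resulting structure.
def _example_curriculum_split():
    curriculum = [3, 7, 0, 9, 4, 1, 8, 2, 5, 6]   # hypothetical class ordering
    total_classes = len(curriculum)
    classes_per_task = 2
    tasks = [curriculum[x:x + classes_per_task]
             for x in range(0, total_classes, classes_per_task)]
    # tasks == [[3, 7], [0, 9], [4, 1], [8, 2], [5, 6]]
    return tasks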
def train(config, method_config, data_config, logger):
    """
    Args:
        config (dict): config file dictionary.
            model (str): name of network. [selector.model(model, ...)]
            classes_per_task (int): classes per task.
            DA (bool): if True, apply data augment.
            memory_cap (int): sample memory size.
            num_workers (int): how many subprocesses to use for data loading.
                0 means that the data will be loaded in the main process. (default: 0)
            batch_size (int): how many samples per batch to load. (default: 1)
            device (torch.device): gpu or cpu.
            data_path (str): root directory of dataset.
            save_path (str): directory for save. (not taskwise)
        data_config (dict): data config file dictionary.
            dataset (str): name of dataset.
            total_classes (int): total class number of dataset.
            transform (dict):
                train (transforms): transforms for train dataset.
                test (transforms): transforms for test dataset.
            curriculums (list): curriculum list.
            classes (list): class name list.
        method_config (dict): method config file dictionary.
            method (str): name of method.
            process_list (list): process list.
            package (string): current package name.
        logger (Logger): logger for the tensorboard.
    """
    model_ = config['model']
    classes_per_task = config['classes_per_task']
    memory_cap = config['memory_cap']
    device = config['device']
    data_path = config['data_path']
    save_path = config['save_path']

    total_classes = data_config['total_classes']
    train_transform = data_config['transform']['train']
    test_transform = data_config['transform']['test']
    curriculum = data_config['curriculums']

    '''Split curriculum [[task0], [task1], [task2], ...]'''
    curriculum = [curriculum[x:x + classes_per_task]
                  for x in range(0, total_classes, classes_per_task)]

    '''Make sample memory'''
    sample_memory = taskset.SampleMemory(data_path, total_classes, len(curriculum), curriculum,
                                         transform=train_transform, capacity=memory_cap)

    model = selector.model(model_, device, total_classes, classes_per_task)

    # Online-EWC state and hyperparameters.
    fisher = None
    importance = 75000
    alpha = 0.9
    normalize = True

    test_task = []
    train_task = []

    '''Taskwise iteration'''
    for task_idx, task in enumerate(curriculum):
        train_task = task
        train_taskset = taskset.Taskset(data_path, train_task, task_idx, train=True,
                                        transform=train_transform)
        test_task.extend(task)
        test_taskset = taskset.Taskset(data_path, test_task, 0, train=False,
                                       transform=test_transform)

        '''Make directory of current task'''
        if not os.path.exists(os.path.join(save_path, 'task%02d' % task_idx)):
            os.makedirs(os.path.join(save_path, 'task%02d' % task_idx))
        config['task_path'] = os.path.join(save_path, 'task%02d' % task_idx)
        config['cml_classes'] = len(test_task)

        ###
        """
        model = selector.model(model, device, len(test_task))
        if method_config["flag_ewc"] == True:
            for n, g in net.named_parameters():
                if 'fc' in n:
                    g[:len(method_config["old_param"][n])].data.copy_(
                        method_config["old_param"][n].data.clone())
                else:
                    g.data.copy_(method_config["old_param"][n].data.clone())
        """
        ###

        # Freeze a copy of the model before training on the current task; it serves as
        # the reference point for the EWC penalty.
        model_old = deepcopy(model)
        for p in model_old.parameters():
            p.requires_grad = False
        ewc = EWCPPLoss(model, model_old, fisher=fisher, alpha=alpha, normalize=normalize)

        model, ewc = _train(task_idx, model, ewc, importance, train_taskset, sample_memory,
                            test_taskset, config, method_config, data_config, logger)

        # Carry the (online) Fisher information over to the next task.
        fisher = deepcopy(ewc.get_fisher())

        # sample_memory.update()
        torch.save(sample_memory, os.path.join(config['task_path'], 'sample_memory'))
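
# --- EWC penalty sketch (generic, not the project's EWCPPLoss) ---
# The train() above delegates the regularization to the project's EWCPPLoss; the
# standalone sketch below only illustrates the underlying idea of an EWC-style
# penalty: a diagonal Fisher estimate weights a quadratic penalty that keeps the
# current parameters close to the ones frozen after the previous task. Names,
# shapes, and the toy data are placeholder assumptions.
def _ewc_penalty_sketch():
    torch.manual_seed(0)
    net = nn.Linear(4, 3)
    old_params = {n: p.detach().clone() for n, p in net.named_parameters()}

    # Diagonal Fisher estimate: average squared gradients of the log-likelihood
    # over a few samples standing in for the previous task's data.
    fisher = {n: torch.zeros_like(p) for n, p in net.named_parameters()}
    data = torch.randn(8, 4)
    labels = torch.randint(0, 3, (8,))
    for x, y in zip(data, labels):
        net.zero_grad()
        loss = nn.functional.cross_entropy(net(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        for n, p in net.named_parameters():
            fisher[n] += p.grad.detach() ** 2 / len(data)

    # Quadratic penalty, scaled by an importance factor (75000 above), that would be
    # added to the current task's loss.
    importance = 75000
    penalty = sum((fisher[n] * (p - old_params[n]) ** 2).sum()
                  for n, p in net.named_parameters())
    return importance * penalty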