def train_layerwise(train_args, model, module_schedule, train_loader, val_loader, device=0):
    """Train a growing set of modules, one schedule stage at a time.

    Stage ``k`` appends the modules of ``module_schedule[k]`` to the set being
    optimized, so modules listed earlier in the schedule are trained first and
    keep being trained in every later stage. Each stage calls ``train`` on the
    accumulated set for as many epochs as there are modules in that set.

    Args:
        train_args: training options forwarded to ``train``. Its ``epochs``
            attribute is overridden per stage and restored before returning.
        model: network being trained (multi-head mode is always used).
        module_schedule: iterable of module groups, each an iterable of modules.
        train_loader: training data forwarded to ``train``.
        val_loader: validation data forwarded to ``train``.
        device: device forwarded to ``train``.
    """
    original_epochs = train_args.epochs
    train_modules = []
    try:
        # modules passed in list will be added to optimization set in schedule order
        for modules in module_schedule:
            modules = list(modules)
            # modules exposing ``update_component`` refresh it before being trained
            for m in modules:
                if hasattr(m, 'update_component'):
                    m.update_component()
            train_modules += modules
            train_args.epochs = len(train_modules)
            # pass a snapshot: ``train_modules`` keeps growing in later stages
            train(train_args, model, train_loader, val_loader, device=device,
                  optimize_modules=list(train_modules), multihead=True)
    finally:
        # do not leak the per-stage epoch override into the caller's config
        train_args.epochs = original_epochs
def consolidate_multi_task(data_args, train_args, model, device=0):
    """Retrain selected layers separately on each task, then average them.

    The layers named in ``train_args.layer`` are optionally swapped for
    ``SuperConv`` / ``L2Conv`` variants, reinitialized, and retrained on each
    task returned by ``get_dataloaders_incr`` in turn. After each task the
    model is evaluated on every task's test loader, and the retrained layer
    weights are saved under the experiment directory. Unless
    ``train_args.incremental`` is set, the layers are restored to the saved
    reinitialization between tasks. At the end they are replaced by the
    element-wise mean of the per-task weights. The resulting model is then
    evaluated on every task.

    Args:
        data_args: dataset options forwarded to the data-loader factories.
        train_args: experiment options (layers to retrain, regularization,
            superimposition, incremental / backward-training flags, ...).
            As a side effect, ``train_args.l2`` is set to True when
            ``train_args.regularization`` is not ``'none'``.
        model: network modified in place. The hard-coded BatchNorm swaps and
            the backward-training schedule assume a torchvision-style ResNet-18
            layout (``conv1``, ``layer1``..``layer4``) -- TODO confirm before
            using other architectures.
        device: device passed to ``train`` / ``test`` and used for the
            averaged weights.
    """
    import os  # local import: recursive creation of the output directory

    if train_args.regularization != 'none':
        train_args.l2 = True

    train_loaders, val_loaders, test_loaders = get_dataloaders_incr(data_args, load_test=True)
    # NOTE(review): ``test_ldr`` is unused below; the call is kept in case
    # get_dataloaders has side effects (e.g. seeding) later code relies on.
    _, _, test_ldr = get_dataloaders(data_args, load_train=False)
    reinit_layers = find_network_modules_by_name(model, train_args.layer)

    if train_args.superimpose:
        # define SuperConv model-wise apply methods
        model.superimpose = apply_module_method_if_exists(model, 'superimpose')
        model.load_superimposed_weight = apply_module_method_if_exists(model, 'load_superimposed_weight')
        model.update_component = apply_module_method_if_exists(model, 'update_component')
        model.scale_supconv_grads = apply_module_method_if_exists(model, 'scale_grad')

        def build_super_conv(conv):
            # channel counts are rounded down to a multiple of redundant_groups
            in_ch = conv.in_channels // train_args.redundant_groups * train_args.redundant_groups
            out_ch = conv.out_channels // train_args.redundant_groups * train_args.redundant_groups
            return SuperConv(in_ch, out_ch, conv.kernel_size, bias=conv.bias is not None,
                             stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
                             groups=conv.groups * train_args.redundant_groups,
                             drop_groups=train_args.drop_groups,
                             bias_sup=train_args.weight_sup_method)

        for i, layer_name in enumerate(train_args.layer):
            old_conv = reinit_layers[i]
            # only plain (exact-type) Conv2d layers are swapped
            if type(old_conv) != nn.Conv2d:
                continue
            sup_conv = build_super_conv(old_conv)
            set_torchvision_network_module(model, layer_name, sup_conv)
            sup_conv.cuda()
            reinit_layers[i] = sup_conv
    elif train_args.l2:
        model.update_previous_params = apply_module_method_if_exists(model, 'update_previous_weight')

        def build_l2_conv(conv):
            # channel counts are rounded down to a multiple of redundant_groups
            in_ch = conv.in_channels // train_args.redundant_groups * train_args.redundant_groups
            out_ch = conv.out_channels // train_args.redundant_groups * train_args.redundant_groups
            return L2Conv(in_ch, out_ch, conv.kernel_size, bias=conv.bias is not None,
                          stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
                          groups=conv.groups * train_args.redundant_groups)

        for i, layer_name in enumerate(train_args.layer):
            old_conv = reinit_layers[i]
            if type(old_conv) != nn.Conv2d:
                continue
            l2_conv = build_l2_conv(old_conv)
            set_torchvision_network_module(model, layer_name, l2_conv)
            l2_conv.cuda()
            reinit_layers[i] = l2_conv

    # disable affine and running stats of retrained bn layers (layer4 only)
    model.layer4[0].bn1 = nn.BatchNorm2d(model.layer4[0].bn1.num_features, affine=False).cuda()
    model.layer4[0].bn2 = nn.BatchNorm2d(model.layer4[0].bn2.num_features, affine=False).cuda()
    model.layer4[0].downsample[1] = nn.BatchNorm2d(model.layer4[0].downsample[1].num_features,
                                                   affine=False).cuda()
    model.layer4[1].bn1 = nn.BatchNorm2d(model.layer4[1].bn1.num_features, affine=False).cuda()
    model.layer4[1].bn2 = nn.BatchNorm2d(model.layer4[1].bn2.num_features, affine=False).cuda()
    model.eval()
    # if not updating bn layer during training, disable model's train mode
    # (model.train becomes a no-op so the model stays in eval mode)
    if not train_args.fit_bn_stats:
        model.train = lambda *args, **kwargs: None

    def save_layer(save_path, suffix='.pth'):
        # one state-dict file per retrained layer: <save_path><layer name><suffix>
        for layer, name in zip(reinit_layers, train_args.layer):
            layer.cpu()
            torch.save(layer.state_dict(), save_path + name + suffix)
            layer.cuda()

    def load_layer(load_path, suffix='.pth'):
        for layer, name in zip(reinit_layers, train_args.layer):
            layer.cpu()
            layer.load_state_dict(torch.load(load_path + name + suffix))
            layer.cuda()

    base_dir = 'models/consolidation_experiments/%s/' % train_args.experiment_id
    base_path = base_dir + '%d-layer/' % len(train_args.layer)
    # makedirs also creates missing parent directories and tolerates existing ones
    os.makedirs(base_path, exist_ok=True)

    # save pretrained parameterization of the layer
    save_layer(base_path, suffix='-full.pth')

    # reinitialize the layer (freshly built SuperConv / L2Conv layers are left as constructed)
    for layer in reinit_layers:
        if type(layer) not in [L2Conv, SuperConv]:
            layer.reset_parameters()
    # NOTE(review): with superimpose and not l2 this file is not written, yet it
    # is loaded below when not incremental -- confirm that combination is unused.
    if not train_args.superimpose or train_args.l2:
        save_layer(base_path, suffix='-reinit.pth')

    # module training schedule for backward training (output layer first)
    module_schedule = [
        (model.layer4[1].conv2,),
        (model.layer4[1].conv1,),
        (model.layer4[0].conv2, model.layer4[0].downsample[0]),
        (model.layer4[0].conv1,),
        (model.layer3[1].conv2,),
        (model.layer3[1].conv1,),
        (model.layer3[0].conv2, model.layer3[0].downsample[0]),
        (model.layer3[0].conv1,),
        (model.layer2[1].conv2,),
        (model.layer2[1].conv1,),
        (model.layer2[0].conv2, model.layer2[0].downsample[0]),
        (model.layer2[0].conv1,),
        (model.layer1[1].conv2,),
        (model.layer1[1].conv1,),
        (model.layer1[0].conv2,),
        (model.layer1[0].conv1,),
        (model.conv1,),
    ]
    # keep only the layers being retrained, dropping stages that become empty
    module_schedule = [[m for m in modules if m in reinit_layers] for modules in module_schedule]
    module_schedule = [modules for modules in module_schedule if len(modules) > 0]

    accuracies = []
    # train separately on each subtask
    for i, (train_loader, val_loader) in enumerate(zip(train_loaders, val_loaders)):
        if train_args.train_backward and i > 0:
            train_layerwise(train_args, model, module_schedule, train_loader, val_loader, device=device)
        else:
            train(train_args, model, train_loader, val_loader, device=device,
                  optimize_modules=reinit_layers, multihead=True)

        if train_args.superimpose:
            model.superimpose(True)

        accs = []
        accuracies += [accs]
        for j, test_loader in enumerate(test_loaders):
            c, t = test(model, test_loader, device=device, multihead=True)
            acc = (c.sum() / t.sum()).item()
            accs += [acc]
            print('Task-%d-trained model accuracy for task %d: %.2f' % (i, j, acc * 100.))

        # load superimposed weight into memory to be saved
        if train_args.superimpose:
            model.load_superimposed_weight()

        # save trained layer
        save_layer(base_path, suffix='-task_%d.pth' % i)

        if not train_args.incremental:
            # reinitialize the layer
            load_layer(base_path, suffix='-reinit.pth')

        # update regularization weighting scheme
        if train_args.regularization not in ['none', 'l2'] or train_args.weight_sup_method is not None:
            collect_l2_weight(model, train_loader, method=train_args.regularization, device=device)

        # reset weight and component in SuperConv
        if train_args.superimpose and not train_args.train_backward:
            model.update_component()
        # update previous parameterization if conducting l2 penalty
        elif train_args.l2 and not train_args.superimpose:
            model.update_previous_params()

    # consolidate using kernel averaging over every trained task
    if not train_args.incremental:
        print('Consolidating separately trained layers...')
        n_tasks = len(train_loaders)
        for layer, name in zip(reinit_layers, train_args.layer):
            w = 0
            for task_idx in range(n_tasks):
                w = w + torch.load(base_path + '%s-task_%d.pth' % (name, task_idx))['weight']
            # divide by the actual task count (was a hard-coded 5)
            layer.weight.data[:] = w.to(device) / n_tasks

    # test consolidated layer
    # (no-op when fit_bn_stats is False, since model.train was replaced above)
    model.train()
    consolidated_accs = []
    for i, test_loader in enumerate(test_loaders):
        c, t = test(model, test_loader, device=device, multihead=True)
        acc = (c.sum() / t.sum()).item()
        print('Accuracy of consolidated model on task %d: %.2f' % (i, acc * 100.))
        consolidated_accs += [acc]
def consolidate_single_task(data_args, train_args, model, device=0):
    """Retrain one layer on disjoint data subsets, then average the results.

    The layer named ``train_args.layer`` is reset and trained separately on
    each subset returned by ``get_subset_data_loaders``. Each trained
    parameterization is saved as ``<base_path><i>.pth``. Unless
    ``train_args.incremental`` is set, the layer gets a fresh random
    reinitialization before each subset. Afterwards the weights of all
    trained subsets are averaged and the result is evaluated on the
    validation set.

    Args:
        data_args: dataset options forwarded to ``get_subset_data_loaders``.
        train_args: experiment options (``layer``, ``num_samples``,
            ``incremental`` and training settings forwarded to ``train``).
        model: network modified in place.
        device: device forwarded to ``train`` and used for the averaged weight.

    Raises:
        ValueError: if consolidation is requested but no subset was trained.
    """
    import os  # local import: creation of the output directory

    train_loaders, val_loader = get_subset_data_loaders(data_args, train_args.num_samples)
    reinit_layer, = find_network_modules_by_name(model, [train_args.layer])

    # test initial accuracy
    c, t = test(model, val_loader)
    pt_accuracy = (c.sum() / t.sum()).item()
    print('Accuracy of fully trained model: %.2f' % (pt_accuracy * 100.))

    def save_layer(save_path):
        reinit_layer.cpu()
        torch.save(reinit_layer.state_dict(), save_path)
        reinit_layer.cuda()

    base_dir = 'models/consolidation_experiments/same_task/'
    base_path = base_dir + train_args.layer + '-diff_reinit-'
    # torch.save does not create missing directories
    os.makedirs(base_dir, exist_ok=True)

    # save pretrained parameterization of final layer
    save_layer(base_path + 'full.pth')

    # reinit final feature layer
    reinit_layer.reset_parameters()
    save_layer(base_path + 'reinit_0.pth')

    accuracies = []
    subset_paths = []
    # train final layer separately on each subset of data
    for i, loader in enumerate(train_loaders):
        train(train_args, model, loader, val_loader, device=device, optimize_modules=[reinit_layer])
        c, t = test(model, val_loader)
        accuracies += [(c.sum() / t.sum()).item()]
        print('Accuracy of model trained on subset %d: %.2f' % (i, accuracies[-1] * 100.))
        subset_path = base_path + str(i) + '.pth'
        save_layer(subset_path)
        subset_paths.append(subset_path)

        if not train_args.incremental:
            # use a different random reinitialization for the next subset
            reinit_layer.reset_parameters()
            save_layer(base_path + 'reinit_%d.pth' % (i + 1))

    if not train_args.incremental:
        # attempt to consolidate separately trained layers into a single representation
        print('Consolidating separately trained layers...')
        if not subset_paths:
            raise ValueError('no trained subsets to consolidate')

        # 1 - naive averaging over every trained subset (previously only subsets 0 and 1)
        w = sum(torch.load(path)['weight'] for path in subset_paths) / len(subset_paths)
        reinit_layer.weight.data[:] = w.to(device)

        # test consolidated model
        c, t = test(model, val_loader)
        consolidated_acc = (c.sum() / t.sum()).item()
        print('Accuracy of consolidated model: %.2f' % (consolidated_acc * 100.))
def train_incr(args: IncrTrainingArgs, model, train_loaders, val_loaders, device=0):
    """Train ``model`` on a sequence of exposures, one after another.

    When ``args.superimpose`` is set, every exact ``torch.nn.Conv2d`` in the
    model is replaced by a freshly constructed ``SuperConv`` with the same
    hyper-parameters. Otherwise, when a regularization method is requested,
    it is replaced by an ``L2Conv``. Only those layers plus ``model.fc`` are
    then optimized; without either option ``train`` receives
    ``optimize_modules=None``.

    For each exposure ``i``:
      - the saved initial state is reloaded first if ``args.exposure_reinit``;
      - the loader's classes are added to ``model.active_outputs``;
      - ``set_task(model, i)`` is called;
      - ``args.acc_save_path`` / ``args.model_save_path`` are given an
        ``-exp<i+1>`` suffix;
      - the model is trained on exposure ``i`` and validated on exposures
        ``0..i``.

    Args:
        args: incremental-training options. ``args.l2`` and the two save
            paths are mutated as described above.
        model: network modified in place.
        train_loaders: one training loader per exposure (each exposing ``classes``).
        val_loaders: one validation loader per exposure.
        device: device forwarded to ``train`` / ``collect_l2_weight``.
    """
    # single run-through of all exposures
    acc_save_path = args.acc_save_path
    model_save_path = args.model_save_path
    model.active_outputs = []

    # set l2 flag
    if args.regularization != 'none':
        args.l2 = True

    # remove affine layer and stats tracking from all batchnorm layers
    if args.reset_bn:
        reset_bn(model)

    def _swap_convs(conv_cls):
        """Replace every exact Conv2d in ``model`` with ``conv_cls``; return the new layers."""
        swapped = []
        # snapshot first: the loop rewrites the module tree it is walking
        for name, module in list(model.named_modules()):
            if type(module) is torch.nn.Conv2d:
                new_conv = conv_cls(module.in_channels, module.out_channels, module.kernel_size,
                                    bias=module.bias is not None, stride=module.stride,
                                    padding=module.padding, dilation=module.dilation,
                                    groups=module.groups).to(module.weight.device)
                set_torchvision_network_module(model, name, new_conv)
                swapped.append(new_conv)
        return swapped

    optimize_modules = None
    if args.superimpose:
        model.superimpose = apply_module_method_if_exists(model, 'superimpose')
        model.load_superimposed_weight = apply_module_method_if_exists(
            model, 'load_superimposed_weight')
        model.update_component = apply_module_method_if_exists(
            model, 'update_component')
        model.scale_supconv_grads = apply_module_method_if_exists(
            model, 'scale_grad')
        optimize_modules = _swap_convs(SuperConv)
        optimize_modules += [model.fc]
    elif args.regularization != 'none':
        model.update_previous_params = apply_module_method_if_exists(
            model, 'update_previous_weight')
        optimize_modules = _swap_convs(L2Conv)
        optimize_modules += [model.fc]

    for i, (train_loader, val_loader) in enumerate(zip(train_loaders, val_loaders)):
        if args.exposure_reinit:
            init_state = torch.load(
                join(args.model_save_dir, append_to_file(model_save_path, 'init')))
            model.cpu().load_state_dict(init_state)
            model.cuda()

        # update active (used) model outputs
        # TODO generalize for exposure repetition
        model.active_outputs += train_loader.classes
        set_task(model, i)

        args.acc_save_path = append_to_file(acc_save_path, '-exp%d' % (i + 1))
        args.model_save_path = append_to_file(model_save_path, '-exp%d' % (i + 1))
        # validate on every exposure seen so far
        train(
            args,
            model,
            train_loader,
            *val_loaders[:i + 1],
            device=device,
            multihead=args.multihead,
            fc_only=False,  # i > 0
            optimize_modules=optimize_modules)

        # load superimposed weight into memory to be saved
        if args.superimpose:
            model.load_superimposed_weight()

        # update regularization weighting scheme
        if args.regularization not in ['none', 'l2']:
            collect_l2_weight(model, train_loader, method=args.regularization, device=device)

        # update component/previous weight
        if args.superimpose:
            model.update_component()
        elif args.regularization != 'none':
            model.update_previous_params()
def main():
    """Parse arguments, build the network, and dispatch to a training routine.

    Routing:
      - ``batch`` without ``multihead``: standard single-head ``train`` on
        loaders from ``get_dataloaders``.
      - ``batch`` with ``multihead``: ``train_batch_multihead``, then the
        model's entropy and class-routing divergence are saved with ``np.savez``.
      - otherwise: ``train_incr`` on the incremental loaders.

    Raises:
        ValueError: if ``model_args.arch`` is not a supported architecture.
    """
    data_args, train_args, model_args = parse_args(IncrDataArgs, ExperimentArgs, AllModelArgs)
    if train_args.batch and not train_args.multihead:
        train_loader, val_loader, test_loader = get_dataloaders(
            data_args, load_test=False)
    else:
        train_loader, val_loader, test_loader = get_dataloaders_incr(
            data_args, load_test=False, multihead_batch=train_args.batch)

    state = None
    # load pretrained feature extractor if specified
    if model_args.load_state_path:
        state = torch.load(model_args.load_state_path)

    if model_args.arch == 'resnet18':
        net = resnet18(num_classes=data_args.num_classes, seed=data_args.seed,
                       disable_bn_stats=model_args.disable_bn_stats)
        if state is not None:
            # keep this model's own classifier head; only the feature extractor is loaded
            state['fc.weight'], state['fc.bias'] = net.fc.weight, net.fc.bias
            net.load_state_dict(state)
    elif model_args.arch == 'lrm_resnet18':
        net = load_lrm(state=state, num_classes=data_args.num_classes, seed=data_args.seed,
                       disable_bn_stats=model_args.disable_bn_stats, n_blocks=model_args.n_blocks,
                       block_size_alpha=model_args.block_size_alpha,
                       route_by_task=model_args.route_by_task, fit_keys=train_args.fit_keys)
    else:
        # previously fell through and failed later with UnboundLocalError on ``net``
        raise ValueError('Unsupported architecture: %s' % model_args.arch)

    # save state initialization if we will be reinitializing the model before each new exposure
    if train_args.exposure_reinit:
        torch.save(
            net.state_dict(),
            join(train_args.model_save_dir, append_to_file(train_args.model_save_path, 'init')))

    net.cuda()

    if train_args.batch:
        if train_args.multihead:
            # trains model on batches of data across tasks while enforcing classification predictions to be within task
            train_batch_multihead(train_args, net, train_loader, val_loader, device=0)
            np.savez(join(train_args.acc_save_dir, train_args.incr_results_path),
                     entropy=net.get_entropy(), class_div=net.get_class_routing_divergence())
        else:
            train(train_args, net, train_loader, val_loader, device=0, multihead=False)
    else:
        train_incr(train_args, net, train_loader, val_loader, device=0)
import torch

from experiment_utils.train_models import get_dataloaders, save_model, test, train
from experiment_utils.argument_parsing import *
from model import resnet18


class LoadModelArgs(ArgumentClass):
    """CLI options for warm-starting the model from a saved state dict."""

    # attribute name -> command-line Argument definition
    ARGS = dict(
        load_model_path=Argument(
            '--load-model-path',
            type=str,
            default=None,
            help='path to model file to load at init',
        ),
    )


if __name__ == '__main__':
    # script entry: build a ResNet-18, optionally load weights, then train it on GPU 0
    model_args, data_args, train_args = parse_args(LoadModelArgs, DataArgs, TrainingArgs)
    train_ldr, val_ldr, _test_ldr = get_dataloaders(data_args)

    net = resnet18(num_classes=data_args.num_classes, seed=data_args.seed)
    checkpoint_path = model_args.load_model_path
    if checkpoint_path:
        net.load_state_dict(torch.load(checkpoint_path))
    net.cuda()

    train(train_args, net, train_ldr, val_ldr, device=0)