Example #1
def train_layerwise(train_args, model, module_schedule, train_loader, val_loader, device=0):
    train_modules = []
    # modules in the schedule are added to the optimization set cumulatively, so an
    # output-to-input ordering trains the network back to front, one group at a time
    for modules in module_schedule:
        for m in modules:
            if hasattr(m, 'update_component'):
                m.update_component()
        train_modules += list(modules)
        # scale the epoch budget with the number of modules currently being optimized
        train_args.epochs = len(train_modules)
        train(train_args, model, train_loader, val_loader, device=device, optimize_modules=train_modules,
              multihead=True)
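A minimal, hypothetical call site for train_layerwise, assuming a torchvision-style ResNet like the one used in Example #2; the schedule below is only a truncated illustration of the output-to-input ordering that example builds.

# hypothetical usage sketch; `train_args`, `model` and the loaders are assumed to come
# from the surrounding experiment setup
module_schedule = [
    (model.layer4[1].conv2,),
    (model.layer4[1].conv1,),
    (model.layer4[0].conv2, model.layer4[0].downsample[0]),
]
train_layerwise(train_args, model, module_schedule, train_loader, val_loader, device=0)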
Example #2
def consolidate_multi_task(data_args, train_args, model, device=0):
    if train_args.regularization != 'none':
        train_args.l2 = True
    train_loaders, val_loaders, test_loaders = get_dataloaders_incr(data_args, load_test=True)
    _, _, test_ldr = get_dataloaders(data_args, load_train=False)

    reinit_layers = find_network_modules_by_name(model, train_args.layer)

    if train_args.superimpose:

        # attach model-wide helpers that call the corresponding SuperConv method on each submodule
        model.superimpose = apply_module_method_if_exists(model, 'superimpose')
        model.load_superimposed_weight = apply_module_method_if_exists(model, 'load_superimposed_weight')
        model.update_component = apply_module_method_if_exists(model, 'update_component')
        model.scale_supconv_grads = apply_module_method_if_exists(model, 'scale_grad')

        def build_super_conv(conv):
            # round channel counts down to a multiple of redundant_groups
            in_ch = conv.in_channels // train_args.redundant_groups * train_args.redundant_groups
            out_ch = conv.out_channels // train_args.redundant_groups * train_args.redundant_groups
            return SuperConv(in_ch, out_ch, conv.kernel_size, bias=conv.bias is not None,
                             stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
                             groups=conv.groups * train_args.redundant_groups, drop_groups=train_args.drop_groups,
                             bias_sup=train_args.weight_sup_method)

        for i, layer_name in enumerate(train_args.layer):
            old_conv = reinit_layers[i]
            if type(old_conv) != nn.Conv2d:
                continue
            sup_conv = build_super_conv(old_conv)
            set_torchvision_network_module(model, layer_name, sup_conv)
            sup_conv.cuda()
            reinit_layers[i] = sup_conv

    elif train_args.l2:
        model.update_previous_params = apply_module_method_if_exists(model, 'update_previous_weight')

        def build_l2_conv(conv):
            in_ch = conv.in_channels // train_args.redundant_groups * train_args.redundant_groups
            out_ch = conv.out_channels // train_args.redundant_groups * train_args.redundant_groups
            return L2Conv(in_ch, out_ch, conv.kernel_size, bias=conv.bias is not None,
                          stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
                          groups=conv.groups * train_args.redundant_groups)

        for i, layer_name in enumerate(train_args.layer):
            old_conv = reinit_layers[i]
            if type(old_conv) != nn.Conv2d:
                continue
            l2_conv = build_l2_conv(old_conv)
            set_torchvision_network_module(model, layer_name, l2_conv)
            l2_conv.cuda()
            reinit_layers[i] = l2_conv

    # disable affine and running stats of retrained bn layers
    """model.bn1 = nn.BatchNorm2d(model.bn1.num_features, affine=False).cuda()
    model.layer1[0].bn1 = nn.BatchNorm2d(model.layer1[0].bn1.num_features, affine=False).cuda()
    model.layer1[0].bn2 = nn.BatchNorm2d(model.layer1[0].bn2.num_features, affine=False).cuda()
    model.layer1[1].bn1 = nn.BatchNorm2d(model.layer1[1].bn1.num_features, affine=False).cuda()
    model.layer1[1].bn2 = nn.BatchNorm2d(model.layer1[1].bn2.num_features, affine=False).cuda()
    model.layer2[0].bn1 = nn.BatchNorm2d(model.layer2[0].bn1.num_features, affine=False).cuda()
    model.layer2[0].bn2 = nn.BatchNorm2d(model.layer2[0].bn2.num_features, affine=False).cuda()
    model.layer2[0].downsample[1] = nn.BatchNorm2d(model.layer2[0].downsample[1].num_features, affine=False).cuda()
    model.layer2[1].bn1 = nn.BatchNorm2d(model.layer2[1].bn1.num_features, affine=False).cuda()
    model.layer2[1].bn2 = nn.BatchNorm2d(model.layer2[1].bn2.num_features, affine=False).cuda()
    model.layer3[0].bn1 = nn.BatchNorm2d(model.layer3[0].bn1.num_features, affine=False).cuda()
    model.layer3[0].bn2 = nn.BatchNorm2d(model.layer3[0].bn2.num_features, affine=False).cuda()
    model.layer3[0].downsample[1] = nn.BatchNorm2d(model.layer3[0].downsample[1].num_features, affine=False).cuda()
    model.layer3[1].bn1 = nn.BatchNorm2d(model.layer3[1].bn1.num_features, affine=False).cuda()
    model.layer3[1].bn2 = nn.BatchNorm2d(model.layer3[1].bn2.num_features, affine=False).cuda()"""
    model.layer4[0].bn1 = nn.BatchNorm2d(model.layer4[0].bn1.num_features, affine=False).cuda()
    model.layer4[0].bn2 = nn.BatchNorm2d(model.layer4[0].bn2.num_features, affine=False).cuda()
    model.layer4[0].downsample[1] = nn.BatchNorm2d(model.layer4[0].downsample[1].num_features, affine=False).cuda()
    model.layer4[1].bn1 = nn.BatchNorm2d(model.layer4[1].bn1.num_features, affine=False).cuda()
    model.layer4[1].bn2 = nn.BatchNorm2d(model.layer4[1].bn2.num_features, affine=False).cuda()

    model.eval()
    # if not updating bn layer during training, disable model's train mode
    if not train_args.fit_bn_stats:
        model.train = lambda *args, **kwargs: None

    # test pretrained model accuracy
    """pt_accuracies = []
    for i, test_loader in enumerate(test_loaders):
        c, t = test(model, test_loader, device=device, multihead=True)
        acc = (c.sum() / t.sum()).item()
        print('Pretrained model accuracy for task %d: %.2f' % (i, acc * 100.))
        pt_accuracies += [acc]"""

    def save_layer(save_path, suffix='.pth'):
        for layer, name in zip(reinit_layers, train_args.layer):
            layer.cpu()
            torch.save(layer.state_dict(), save_path + name + suffix)
            layer.cuda()

    def load_layer(load_path, suffix='.pth'):
        for layer, name in zip(reinit_layers, train_args.layer):
            layer.cpu()
            layer.load_state_dict(torch.load(load_path + name + suffix))
            layer.cuda()

    base_dir = 'models/consolidation_experiments/%s/' % train_args.experiment_id
    base_path = base_dir + '%d-layer/' % len(train_args.layer)

    if not exists(base_dir):
        mkdir(base_dir)

    if not exists(base_path):
        mkdir(base_path)

    # covariance experimentation
    """from sklearn.covariance import EmpiricalCovariance

    def get_cov(ldr, sample_idxs=slice(0, 64), normalize=False):
        feature_layer = reinit_layers[-1]
        load_layer(base_path, suffix='-task_0.pth')
        f1 = compute_features(model, feature_layer, ldr)
        load_layer(base_path, suffix='-task_1.pth')
        f2 = compute_features(model, feature_layer, ldr)

        # subsample
        f1 = torch.cat(f1)[:,sample_idxs].flatten(start_dim=1)
        f2 = torch.cat(f2)[:,sample_idxs].flatten(start_dim=1)
        fcat = torch.cat([f1, f2], dim=1)

        length = f1.shape[1]

        cov = EmpiricalCovariance().fit(fcat).covariance_
        
        if normalize:
            cov = cov ** 2 / (cov ** 2).sum(axis=0)[None, :] / (cov ** 2).sum(axis=1)[:, None]
        
        cov1 = cov[:length, :length]
        cov2 = cov[length:, length:]
        xcov = cov[:length, length:]

        return cov1, cov2, xcov

    def get_kernel_sim():
        pass"""

    # save pretrained parameterization of the layer
    save_layer(base_path, suffix='-full.pth')

    # reinitialize the layer
    for layer in reinit_layers:
        if type(layer) not in [L2Conv, SuperConv]:
            layer.reset_parameters()
    if not train_args.superimpose or train_args.l2:
        save_layer(base_path, suffix='-reinit.pth')

    # module training schedule for backward training
    module_schedule = [
        (model.layer4[1].conv2,),
        (model.layer4[1].conv1,),
        (model.layer4[0].conv2, model.layer4[0].downsample[0]),
        (model.layer4[0].conv1,),
        (model.layer3[1].conv2,),
        (model.layer3[1].conv1,),
        (model.layer3[0].conv2, model.layer3[0].downsample[0]),
        (model.layer3[0].conv1,),
        (model.layer2[1].conv2,),
        (model.layer2[1].conv1,),
        (model.layer2[0].conv2, model.layer2[0].downsample[0]),
        (model.layer2[0].conv1,),
        (model.layer1[1].conv2,),
        (model.layer1[1].conv1,),
        (model.layer1[0].conv2,),
        (model.layer1[0].conv1,),
        (model.conv1,)
    ]
    module_schedule = [[m for m in modules if m in reinit_layers] for modules in module_schedule]
    module_schedule = [modules for modules in module_schedule if len(modules) > 0]

    accuracies = []
    # train separately on each subtask
    for i, (train_loader, val_loader) in enumerate(zip(train_loaders, val_loaders)):
        if train_args.train_backward and i > 0:
            train_layerwise(train_args, model, module_schedule, train_loader, val_loader, device=device)
        else:
            train(train_args, model, train_loader, val_loader, device=device, optimize_modules=reinit_layers,
                  multihead=True)
        if train_args.superimpose:
            model.superimpose(True)
        accs = []
        accuracies += [accs]
        for j, test_loader in enumerate(test_loaders):
            c, t = test(model, test_loader, device=device, multihead=True)
            acc = (c.sum() / t.sum()).item()
            accs += [acc]
            print('Task-%d-trained model accuracy for task %d: %.2f' % (i, j, acc * 100.))

        # load superimposed weight into memory to be saved
        if train_args.superimpose:
            model.load_superimposed_weight()

        # save trained layer
        save_layer(base_path, suffix='-task_%d.pth' % i)

        if not train_args.incremental:
            # reinitialize the layer
            load_layer(base_path, suffix='-reinit.pth')

        # update regularization weighting scheme
        if train_args.regularization not in ['none', 'l2'] or train_args.weight_sup_method is not None:
            collect_l2_weight(model, train_loader, method=train_args.regularization, device=device)

        # reset weight and component in SuperConv
        if train_args.superimpose and not train_args.train_backward:
            model.update_component()

        # update previous parameterization if conducting l2 penalty
        elif train_args.l2 and not train_args.superimpose:
            model.update_previous_params()

    # consolidate using kernel averaging
    if not train_args.incremental:
        print('Consolidating separately trained layers...')

    """threshold = 0.3
    for layer, name in zip(reinit_layers, train_args.layer):
        w = torch.load(base_path + '%s-task_%d.pth' % (name, 0))['weight']
        n_consolidated = torch.ones_like(w)
        for i in range(1, 5):
            new_w = torch.load(base_path + '%s-task_%d.pth' % (name, i))['weight']
            # TODO normalize by distribution of weights in each layer
            diff = ((w - new_w) ** 2).sum(axis=(1, 2, 3)) ** (1/2)
            consolidate = diff < threshold
            w[consolidate] = w[consolidate] + new_w[consolidate]
            n_consolidated[consolidate] += 1

        perc_consolidated = len(np.where(n_consolidated > 1)[0]) / n_consolidated.flatten().shape[0]
        print('%.2f %% of weights consolidated for layer %s' % (perc_consolidated * 100., name))

        w /= n_consolidated
        layer.cpu()
        layer.weight.data[:] = w
        layer.cuda()"""

    if not train_args.incremental:
        for layer, name in zip(reinit_layers, train_args.layer):
            w = 0
            for i in range(len(train_loaders)):
                w = w + torch.load(base_path + '%s-task_%d.pth' % (name, i))['weight']

            # element-wise average of the per-task weights
            layer.weight.data[:] = w.to(device) / len(train_loaders)

    # test consolidated layer (train() is a no-op unless fit_bn_stats is set, in which
    # case the bn layers refit their running statistics on the test batches)
    model.train()
    consolidated_accs = []
    for i, test_loader in enumerate(test_loaders):
        c, t = test(model, test_loader, device=device, multihead=True)
        acc = (c.sum() / t.sum()).item()
        print('Accuracy of consolidated model on task %d: %.2f' % (i, acc * 100.))
        consolidated_accs += [acc]
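The helper apply_module_method_if_exists is used throughout these examples but not defined in them. A minimal sketch consistent with the call sites (an assumption, not the repository's actual implementation) is:

def apply_module_method_if_exists(model, method_name):
    # return a callable that forwards its arguments to `method_name` on every
    # submodule that defines it (skipping the root model to avoid recursing into
    # the attribute being assigned, e.g. model.superimpose)
    def _apply(*args, **kwargs):
        for module in model.modules():
            if module is not model and hasattr(module, method_name):
                getattr(module, method_name)(*args, **kwargs)
    return _apply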
Example #3
def consolidate_single_task(data_args, train_args, model, device=0):
    train_loaders, val_loader = get_subset_data_loaders(data_args, train_args.num_samples)

    reinit_layer, = find_network_modules_by_name(model, [train_args.layer])

    # test initial accuracy
    c, t = test(model, val_loader)
    pt_accuracy = (c.sum() / t.sum()).item()
    print('Accuracy of fully trained model: %.2f' % (pt_accuracy * 100.))

    def save_layer(save_path):
        reinit_layer.cpu()
        torch.save(reinit_layer.state_dict(), save_path)
        reinit_layer.cuda()

    def load_layer(load_path):
        reinit_layer.cpu()
        reinit_layer.load_state_dict(torch.load(load_path))
        reinit_layer.cuda()

    base_dir = 'models/consolidation_experiments/same_task/'
    base_path = base_dir + train_args.layer + '-diff_reinit-'

    # make sure the save directory exists before writing checkpoints
    if not exists(base_dir):
        mkdir(base_dir)

    # save pretrained parameterization of final layer
    save_layer(base_path + 'full.pth')

    # reinit final feature layer
    reinit_layer.reset_parameters()
    save_layer(base_path + 'reinit_0.pth')

    accuracies = []

    # train final layer separately on each subset of data
    for i, loader in enumerate(train_loaders):
        train(train_args, model, loader, val_loader, device=device, optimize_modules=[reinit_layer])
        c, t = test(model, val_loader)
        accuracies += [(c.sum() / t.sum()).item()]
        print('Accuracy of model trained on subset %d: %.2f' % (i, accuracies[-1] * 100.))

        save_layer(base_path + str(i) + '.pth')

        if not train_args.incremental:
            # use different reinitialization
            #load_layer(base_path + 'reinit.pth')
            reinit_layer.reset_parameters()
            save_layer(base_path + 'reinit_%d.pth' % (i + 1))

    if not train_args.incremental:
        # attempt to consolidate separately trained layers into a single representation
        print('Consolidating separately trained layers...')

        # 1 - naive averaging
        state1 = torch.load(base_path + '0.pth')
        state2 = torch.load(base_path + '1.pth')

        w = (state1['weight'] + state2['weight']) / 2
        reinit_layer.weight.data[:] = w.to(device)

    # test consolidated model
    c, t = test(model, val_loader)
    consolidated_acc = (c.sum() / t.sum()).item()
    print('Accuracy of consolidated model: %.2f' % (consolidated_acc * 100.))
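Both consolidation routines end by averaging per-task weight checkpoints element-wise; a small standalone helper capturing that step (purely illustrative, assuming the checkpoints were written by the save_layer helpers above) might look like:

import torch

def average_saved_weights(paths, device=0):
    # element-wise mean of the 'weight' tensors stored in the saved state dicts
    w = 0
    for path in paths:
        w = w + torch.load(path)['weight']
    return (w / len(paths)).to(device)

# e.g. reinit_layer.weight.data[:] = average_saved_weights([base_path + '0.pth', base_path + '1.pth'])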
Example #4
def train_incr(args: IncrTrainingArgs,
               model,
               train_loaders,
               val_loaders,
               device=0):
    # single run-through of all exposures
    acc_save_path = args.acc_save_path
    model_save_path = args.model_save_path
    running_test_results = [[] for _ in range(len(train_loaders))]
    model.active_outputs = []

    # set l2 flag
    if args.regularization != 'none':
        args.l2 = True

    # remove affine layer and stats tracking from all batchnorm layers
    if args.reset_bn:
        reset_bn(model)

    optimize_modules = None
    if args.superimpose:
        model.superimpose = apply_module_method_if_exists(model, 'superimpose')
        model.load_superimposed_weight = apply_module_method_if_exists(
            model, 'load_superimposed_weight')
        model.update_component = apply_module_method_if_exists(
            model, 'update_component')
        model.scale_supconv_grads = apply_module_method_if_exists(
            model, 'scale_grad')

        def build_super_conv(conv):
            return SuperConv(conv.in_channels,
                             conv.out_channels,
                             conv.kernel_size,
                             bias=conv.bias is not None,
                             stride=conv.stride,
                             padding=conv.padding,
                             dilation=conv.dilation,
                             groups=conv.groups)

        optimize_modules = []
        for name, module in model.named_modules():
            if type(module) == torch.nn.Conv2d:
                sup_conv = build_super_conv(module).to(module.weight.device)
                set_torchvision_network_module(model, name, sup_conv)
                optimize_modules += [sup_conv]
        optimize_modules += [model.fc]

    elif args.regularization != 'none':
        model.update_previous_params = apply_module_method_if_exists(
            model, 'update_previous_weight')

        def build_l2_conv(conv):
            return L2Conv(conv.in_channels,
                          conv.out_channels,
                          conv.kernel_size,
                          bias=conv.bias is not None,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups)

        optimize_modules = []
        for name, module in model.named_modules():
            if type(module) == torch.nn.Conv2d:
                l2_conv = build_l2_conv(module).to(module.weight.device)
                set_torchvision_network_module(model, name, l2_conv)
                optimize_modules += [l2_conv]
        optimize_modules += [model.fc]

    for i, (train_loader,
            val_loader) in enumerate(zip(train_loaders, val_loaders)):
        if args.exposure_reinit:
            init_state = torch.load(
                join(args.model_save_dir,
                     append_to_file(model_save_path, 'init')))
            model.cpu().load_state_dict(init_state)
            model.cuda()

        # update active (used) model outputs
        # TODO generalize for exposure repetition
        model.active_outputs += train_loader.classes
        set_task(model, i)

        args.acc_save_path = append_to_file(acc_save_path, '-exp%d' % (i + 1))
        args.model_save_path = append_to_file(model_save_path,
                                              '-exp%d' % (i + 1))
        train(
            args,
            model,
            train_loader,
            *val_loaders[:i + 1],
            device=device,
            multihead=args.multihead,
            fc_only=False,  #i > 0
            optimize_modules=optimize_modules)

        # load superimposed weight into memory to be saved
        if args.superimpose:
            model.load_superimposed_weight()

        # update regularization weighting scheme
        if args.regularization not in ['none', 'l2']:
            collect_l2_weight(model,
                              train_loader,
                              method=args.regularization,
                              device=device)
        """
        print('Testing over all %d previously learned tasks...' % (i + 1))
        mean_acc = total_classes = 0
        model.eval()
        if args.superimpose:
            model.superimpose(True)
        for j, test_loader in enumerate(val_loaders[:i+1]):
            set_task(model, j)

            correct, total = test(model, test_loader, device=device, multihead=args.multihead)
            accuracy = correct / total * 100.
            running_test_results[j] += [accuracy]
            mean_acc += accuracy.sum()
            total_classes += len(test_loader.classes)
        mean_acc = mean_acc / total_classes
        print("Mean accuracy over all %d previously learned tasks: %.4f" % (i + 1, mean_acc))
        """

        # update component/previous weight
        if args.superimpose:
            model.update_component()
        elif args.regularization != 'none':
            model.update_previous_params()
        """
Example #5
def main():
    data_args, train_args, model_args = parse_args(IncrDataArgs,
                                                   ExperimentArgs,
                                                   AllModelArgs)
    if train_args.batch and not train_args.multihead:
        train_loader, val_loader, test_loader = get_dataloaders(
            data_args, load_test=False)
    else:
        train_loader, val_loader, test_loader = get_dataloaders_incr(
            data_args, load_test=False, multihead_batch=train_args.batch)

    state = None
    # load pretrained feature extractor if specified
    if model_args.load_state_path:
        state = torch.load(model_args.load_state_path)

    if model_args.arch == 'resnet18':
        net = resnet18(num_classes=data_args.num_classes,
                       seed=data_args.seed,
                       disable_bn_stats=model_args.disable_bn_stats)
        if state is not None:
            state['fc.weight'], state['fc.bias'] = net.fc.weight, net.fc.bias
            net.load_state_dict(state)
    elif model_args.arch == 'lrm_resnet18':
        net = load_lrm(state=state,
                       num_classes=data_args.num_classes,
                       seed=data_args.seed,
                       disable_bn_stats=model_args.disable_bn_stats,
                       n_blocks=model_args.n_blocks,
                       block_size_alpha=model_args.block_size_alpha,
                       route_by_task=model_args.route_by_task,
                       fit_keys=train_args.fit_keys)
    else:
        raise ValueError('Unrecognized architecture: %s' % model_args.arch)

    # save state initialization if we will be reinitializing the model before each new exposure
    if train_args.exposure_reinit:
        torch.save(
            net.state_dict(),
            join(train_args.model_save_dir,
                 append_to_file(train_args.model_save_path, 'init')))
    net.cuda()

    if train_args.batch:
        if train_args.multihead:
            # train the model on batches drawn across tasks while restricting predictions to each sample's own task
            train_batch_multihead(train_args,
                                  net,
                                  train_loader,
                                  val_loader,
                                  device=0)
            np.savez(join(train_args.acc_save_dir,
                          train_args.incr_results_path),
                     entropy=net.get_entropy(),
                     class_div=net.get_class_routing_divergence())
        else:
            train(train_args,
                  net,
                  train_loader,
                  val_loader,
                  device=0,
                  multihead=False)
    else:
        train_incr(train_args, net, train_loader, val_loader, device=0)
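append_to_file comes from the experiment utilities and is not shown here; judging from the call sites above it inserts a suffix into a file name. A hypothetical sketch (the exact placement of the suffix is an assumption):

import os

def append_to_file(path, suffix):
    # e.g. append_to_file('model.pth', '-exp1') -> 'model-exp1.pth'  (assumed behavior)
    root, ext = os.path.splitext(path)
    return root + suffix + ext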
Example #6
import torch
from experiment_utils.train_models import get_dataloaders, save_model, test, train
from experiment_utils.argument_parsing import *
from model import resnet18


class LoadModelArgs(ArgumentClass):
    ARGS = {
        'load_model_path':
        Argument('--load-model-path',
                 type=str,
                 default=None,
                 help='path to model file to load at init')
    }


if __name__ == '__main__':
    model_args, data_args, train_args = parse_args(LoadModelArgs, DataArgs,
                                                   TrainingArgs)
    train_loader, val_loader, test_loader = get_dataloaders(data_args)
    model = resnet18(num_classes=data_args.num_classes, seed=data_args.seed)
    if model_args.load_model_path:
        model.load_state_dict(torch.load(model_args.load_model_path))
    model.cuda()
    train(train_args, model, train_loader, val_loader, device=0)