Example No. 1
def train(args, model, device, train_loader, test_loader, optimizer):
    for epoch in range(args.num_pre_epochs):
        print('Pre epoch: {}'.format(epoch + 1))
        model.train()
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = regularized_nll_loss(args, model, output, target)
            loss.backward()
            optimizer.step()
        test(args, model, device, test_loader)

    Z, U = initialize_Z_and_U(model)
    for epoch in range(args.num_epochs):
        model.train()
        print('Epoch: {}'.format(epoch + 1))
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = admm_loss(args, device, model, Z, U, output, target)
            loss.backward()
            optimizer.step()
        X = update_X(model)
        Z = update_Z_l1(X, U, args) if args.l1 else update_Z(X, U, args)
        U = update_U(U, X, Z)
        print_convergence(model, X, Z)
        test(args, model, device, test_loader)
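
The helper functions called above (initialize_Z_and_U, admm_loss, update_X, update_Z, update_U) are not shown on this page. The sketch below is a minimal, hedged reconstruction of what they might look like under the standard ADMM weight-pruning formulation; the meaning of args.rho (penalty weight) and args.percent (fraction of entries pruned per layer) are assumptions, not taken from the example, and a real implementation would typically restrict the loops to conv/linear weights.

import torch
import torch.nn.functional as F


def initialize_Z_and_U(model):
    # Z starts as a copy of every weight tensor, U (the scaled dual variable) starts at zero
    Z = tuple(p.detach().clone() for n, p in model.named_parameters() if n.endswith("weight"))
    U = tuple(torch.zeros_like(z) for z in Z)
    return Z, U


def admm_loss(args, device, model, Z, U, output, target):
    # task loss plus the augmented-Lagrangian term (rho / 2) * ||W - Z + U||^2 per layer
    loss = F.nll_loss(output, target)
    idx = 0
    for name, param in model.named_parameters():
        if name.endswith("weight"):
            loss = loss + args.rho / 2 * (param - Z[idx] + U[idx]).norm() ** 2
            idx += 1
    return loss


def update_X(model):
    # snapshot of the current (trained) weights
    return tuple(p.detach().clone() for n, p in model.named_parameters() if n.endswith("weight"))


def update_Z(X, U, args):
    # project X + U onto the sparsity constraint: keep only the largest-magnitude entries
    new_Z = []
    for x, u in zip(X, U):
        z = x + u
        k = int(z.numel() * args.percent)  # args.percent: fraction of entries to prune (assumed)
        if k > 0:
            threshold = z.abs().flatten().kthvalue(k).values
            z = torch.where(z.abs() > threshold, z, torch.zeros_like(z))
        new_Z.append(z)
    return tuple(new_Z)


def update_U(U, X, Z):
    # dual ascent step: U <- U + X - Z
    return tuple(u + x - z for u, x, z in zip(U, X, Z))
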
Example No. 2
def train(args, model, device, train_loader, test_loader, optimizer):
    train_start = time.time()
    Z, U = initialize_Z_and_U(model, device)
    for epoch in range(args.num_epochs):
        print('Epoch: {}'.format(epoch + 1))
        model.train()
        epoch_start = time.time()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = admm_loss(args, device, model, Z, U, output, target)
            loss.backward()
            optimizer.step()
        epoch_end = time.time()
        print("train epoch time cost: {}".format(epoch_end - epoch_start))
        admm_step_start = time.time()
        X = update_X(model, device)
        Z = update_Z_l1(X, U, args) if args.l1 else update_Z(
            X, U, args, device)
        U = update_U(U, X, Z)
        admm_step_end = time.time()
        print("admm step time cost: {}".format(admm_step_end -
                                               admm_step_start))
        print_convergence(model, X, Z)
        test(args, model, device, test_loader)
    train_end = time.time()
    print("train total time cost: {}".format(train_end - train_start))
Example No. 3
def train(args, model, device, train_loader, test_loader, optimizer):
    loss_iter = []
    for epoch in range(args.num_pre_epochs):

        print('Pre epoch: {}'.format(epoch + 1))
        model.train()
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = regularized_nll_loss(args, model, output, target)
            loss_iter.append(loss.item())  # store the scalar loss so the graph is not retained
            loss.backward()
            optimizer.step()
        test(args, model, device, test_loader)

    Z, U = initialize_Z_and_U(model)  # initialize Z and U
    A = np.zeros((args.idx, args.num_epochs))
    for epoch in range(args.num_epochs):
        model.train()
        print('Epoch: {}'.format(epoch + 1))
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = admm_loss(args, device, model, Z, U, output, target)
            loss.backward()
            optimizer.step()
        X = update_X(model)  # update X
        # the Z-update is chosen according to the regularizer
        if args.l1:
            Z = update_Z_l1(X, U, args)
        elif args.l0:
            Z = update_Z_l0(X, U, args)
        elif args.SCAD:
            Z = update_Z_SCAD(X, U, args)
        elif args.rscad:
            print('use rscad update z')
            Z = updata_Z_Prox_glarho(X, U, args)
        else:
            Z = update_Z(X, U, args)
        # (Z updated above according to the chosen sparsity term)
        U = update_U(U, X, Z)

        if not args.test_lamda:
            a = print_convergence(model, X, Z)
            for i in range(args.idx):
                A[i, epoch] = a[i]

        test(args, model, device, test_loader)
    return A
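
Example No. 3 stores per-layer convergence values in the array A, so its print_convergence must return one number per tracked weight tensor. A minimal sketch, assuming the reported quantity is the relative gap ||X - Z|| / ||X|| per layer (the exact metric is an assumption):

import torch


def print_convergence(model, X, Z):
    # report how far the trained weights are from their projected (sparse) copies;
    # returning the per-layer gaps lets the caller log them, as Example No. 3 does with A
    gaps = []
    idx = 0
    for name, _ in model.named_parameters():
        if name.endswith("weight"):
            gap = (X[idx] - Z[idx]).norm().item() / X[idx].norm().item()
            print("{}: ||X - Z|| / ||X|| = {:.6f}".format(name, gap))
            gaps.append(gap)
            idx += 1
    return gaps
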
Example No. 4
def main(args, layer_train_para, layer_names, layer_kernel_inc, pattern):
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)

    # layer_train_para = [
    #     "layer1.0.conv1.weight",
    #     "layer1.0.bn1.weight",
    #     "layer1.0.bn1.bias",
    #     "layer1.0.conv2.weight",
    #     "layer1.0.bn2.weight",
    #     "layer1.0.bn2.bias",
    #     "layer1.1.conv1.weight",
    #     "layer1.1.bn1.weight",
    #     "layer1.1.bn1.bias",
    #     "layer1.1.conv2.weight",
    #     "layer1.1.bn2.weight",
    #     "layer1.1.bn2.bias",
    #     "layer2.0.conv2.weight",
    #     "layer2.0.bn2.weight",
    #     "layer2.0.bn2.bias",
    #     "layer2.0.conv1.weight",
    #     "layer2.0.bn1.weight",
    #     "layer2.0.bn1.bias",
    #     "layer2.0.downsample.0.weight",
    #     "layer2.0.downsample.1.weight",
    #     "layer2.0.downsample.1.bias"]
    #
    # layer_names = [
    #     "layer1.0.conv1",
    #     "layer1.0.conv2",
    #     "layer1.1.conv1",
    #     "layer1.1.conv2",
    #     "layer2.0.conv2",
    #     "layer2.1.conv1",
    #     "layer2.1.conv2"
    # ]
    #
    # layer_kernel_inc = [
    #     # "layer2.0.conv1",
    #     # "layer2.0.downsample.0"
    # ]
    #
    # pattern = {}
    # pattern[0] = torch.tensor([[0, 0, 0],
    #                            [1, 1, 1],
    #                            [1, 1, 1]], dtype=torch.float32)
    #
    # pattern[1] = torch.tensor([[1, 1, 1],
    #                            [1, 1, 1],
    #                            [0, 0, 0]], dtype=torch.float32)
    #
    # pattern[2] = torch.tensor([[1, 1, 0],
    #                            [1, 1, 0],
    #                            [1, 1, 0]], dtype=torch.float32)
    #
    # pattern[3] = torch.tensor([[0, 1, 1],
    #                            [0, 1, 1],
    #                            [0, 1, 1]], dtype=torch.float32)

    layers = {}
    ki_layers = {}
    for layer_name, layer in model.named_modules():
        if isinstance(layer, nn.Conv2d):
            # if is_same(layer.kernel_size) == 3 and layer.in_channels == 512:
            # if is_same(layer.kernel_size) == 3:
            if layer_name in layer_names:
                # layer_names.append(layer_name)
                layers[layer_name] = layer
            if layer_name in layer_kernel_inc:
                ki_layers[layer_name] = layer

        # print(layer_name)
        # if is_same(layer.kernel_size) == 3 and layer.in_channels==512:
        #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #     mask = torch.tensor([[1, 1, 1], [1, 1, 0], [1, 0, 0]], dtype=torch.float32, device=device)
        #     ztNAS_add_kernel_mask(model, layer, layer_name, mask=mask)

    #model = modify_model(model)

    # for name, param in model.named_parameters():
    #     names = [n + "." for n in name.split(".")[:-1]]
    #     if "".join(names)[:-1] not in layer_names:
    #         param.requires_grad = False
    #     else:
    #         break

    for name, param in model.named_parameters():
        if name in layer_train_para:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # for name, param in model.named_parameters():
    #     print(name, param.requires_grad, param.data.shape)

    # print(model)

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    admm_optimizer = torch.optim.Adam(model.parameters(),
                                      lr=args.lr,
                                      eps=args.adam_epsilon)

    admm_re_train_optimizer = PruneAdam(model.named_parameters(),
                                        lr=args.lr,
                                        eps=args.adam_epsilon)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=args.lr_step_size,
                                                   gamma=args.lr_gamma)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        # for name, param in model.named_parameters():
        #     print(name)
        #     print(param)

        layer_pattern = utils.get_layers_pattern(model, layer_names, pattern,
                                                 device)
        utils.print_prune(model, layer_names, layer_pattern)

        for layer_name in layer_names:
            ztNAS_add_kernel_mask(model,
                                  layers[layer_name],
                                  layer_name,
                                  is_pattern=True,
                                  pattern=layer_pattern[layer_name].to(device))

        # print(model)
        model.to(device)
        evaluate(model, criterion, data_loader_test, device=device)

        # evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.retrain_only:
        epoch = 999
        print("Start re-training")
        start_time = time.time()
        print("=" * 10, "Applying pruning model")
        layer_pattern = utils.get_layers_pattern(model, layer_names, pattern,
                                                 device)
        # utils.print_prune(model, layer_names, layer_pattern)

        for layer_name in layer_names:
            ztNAS_add_kernel_mask(model,
                                  layers[layer_name],
                                  layer_name,
                                  is_pattern=True,
                                  pattern=layer_pattern[layer_name].to(device))

        for layer_name in layer_kernel_inc:
            ztNAS_modify_kernel_shape(model, ki_layers[layer_name], layer_name,
                                      2)

        # print(model)
        model.to(device)
        # evaluate(model, criterion, data_loader_test, device=device)

        print("=" * 10, "Retrain")

        re_train_one_epoch(model, criterion, admm_re_train_optimizer,
                           data_loader, device, epoch, args.print_freq,
                           layer_names, layer_pattern, data_loader_test,
                           args.exploration, args.apex)

        acc1, acc5 = evaluate(model,
                              criterion,
                              data_loader_test,
                              device=device,
                              exploration=args.exploration)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

        return acc1, acc5

    print("Start training")
    start_time = time.time()

    Z, U = utils.initialize_Z_and_U(model, layer_names)
    rho = args.rho
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # note: `percent` is not defined in this function; it is assumed to be a
        # module-level mapping of per-layer pruning ratios
        Z, U = train_one_epoch(model, criterion, admm_optimizer, data_loader,
                               device, epoch, args.print_freq, layer_names,
                               percent, pattern, Z, U, rho, args.apex)

        rho = rho * 10
        lr_scheduler.step()

        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

        evaluate(model, criterion, data_loader_test, device=device)

    print("=" * 10, "Applying pruning model")
    layer_pattern = utils.get_layers_pattern(model, layer_names, pattern,
                                             device)
    # utils.print_prune(model, layer_names, layer_pattern)

    for layer_name in layer_names:
        ztNAS_add_kernel_mask(model,
                              layers[layer_name],
                              layer_name,
                              is_pattern=True,
                              pattern=layer_pattern[layer_name].to(device))

    # print(model)
    model.to(device)
    # evaluate(model, criterion, data_loader_test, device=device)

    print("=" * 10, "Retrain")

    re_train_one_epoch(model, criterion, admm_re_train_optimizer, data_loader,
                       device, epoch, args.print_freq, layer_names,
                       layer_pattern, data_loader_test, args.exploration,
                       args.apex)

    evaluate(model, criterion, data_loader_test, device=device)

    if args.output_dir:
        checkpoint = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch + 1,
            'args': args
        }
        utils.save_on_master(
            checkpoint,
            os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
        utils.save_on_master(checkpoint,
                             os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
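
ztNAS_add_kernel_mask is not defined on this page. One common way to apply a fixed 3x3 kernel pattern like the ones in the commented-out pattern dictionary of Example No. 4 is to zero the excluded kernel positions and mask their gradients so they stay zero during retraining; the helper below (apply_kernel_pattern) is a hypothetical stand-in for that idea, not the repo's implementation.

import torch
import torch.nn as nn


def apply_kernel_pattern(layer, pattern):
    # hypothetical stand-in for ztNAS_add_kernel_mask: zero the kernel entries
    # excluded by the 3x3 pattern and keep them zero during training
    assert isinstance(layer, nn.Conv2d) and layer.kernel_size == (3, 3)
    mask = pattern.to(layer.weight.device)  # shape (3, 3), broadcast over out/in channels
    with torch.no_grad():
        layer.weight.mul_(mask)
    # masking the gradient keeps the pruned positions at zero after each update
    layer.weight.register_hook(lambda grad: grad * mask)
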