def main(args):
    """Build the requested backbone, optionally set up the EWC loss, and train.

    Args:
        args: Parsed options object. This function reads ``dataset``,
            ``label``, ``lr``, ``model``, ``ewc``, ``batch_size``,
            ``train_lookup_table`` and ``test_lookup_table``; ``resnet_type``
            or ``timesformer_mode`` depending on ``model``; and ``diag_path``
            when ``ewc`` is truthy. The whole object is also forwarded to
            ``train_model``.

    Raises:
        ValueError: If ``args.model`` is neither ``'resnet'`` nor
            ``'timesformer'``.
    """
    print('Dataset: {}, Normal Label: {}, LR: {}'.format(args.dataset, args.label, args.lr))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    model_type = args.model
    if model_type == 'resnet':
        model = utils.get_resnet_model(resnet_type=args.resnet_type)
        if args.dataset == 'rsna3D':
            # The 3D dataset gets the ResNet wrapped in ResNet3D.
            model = ResNet3D(model)
    elif model_type == 'timesformer':
        model = utils.get_timesformer_model(mode=args.timesformer_mode)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``model`` on the next statement.
        raise ValueError(
            "Unsupported model type: {!r} (expected 'resnet' or 'timesformer')".format(model_type))
    model = model.to(device)

    ewc_loss = None

    # EWC compares against a frozen snapshot of the model taken before
    # freeze_parameters() is applied to the trainable copy.
    if args.ewc:
        frozen_model = deepcopy(model).to(device)
        frozen_model.eval()
        utils.freeze_model(frozen_model)
        fisher = torch.load(args.diag_path)
        ewc_loss = EWCLoss(frozen_model, fisher)

    utils.freeze_parameters(model)

    sorted_train_loader, shuffled_train_loader, test_loader = utils.get_loaders(
        dataset=args.dataset,
        label_class=args.label,
        batch_size=args.batch_size,
        lookup_tables_paths=(args.train_lookup_table, args.test_lookup_table))
    train_model(model, sorted_train_loader, shuffled_train_loader, test_loader, device, args, ewc_loss)
Beispiel #2
0
def main(args):
    """Swap the ResNet's ``fc`` layer for a 1-output linear layer and train.

    Reads ``dataset``, ``label``, ``lr``, ``resnet_type``, ``latent_dim_size``,
    ``batch_size`` and ``epochs`` from ``args``.
    """
    print('Dataset: {}, Label: {}, LR: {}'.format(args.dataset, args.label, args.lr))

    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if has_cuda else "cpu")

    net = utils.get_resnet_model(resnet_type=args.resnet_type)
    # New head: latent_dim_size inputs -> a single output.
    net.fc = torch.nn.Linear(args.latent_dim_size, 1)
    net = net.to(device)

    # train_fc=True is passed through to the project helper; presumably it
    # keeps the new head trainable -- verify in utils.freeze_parameters.
    utils.freeze_parameters(net, train_fc=True)

    loaders = utils.get_loaders(
        dataset=args.dataset,
        label_class=args.label,
        batch_size=args.batch_size,
    )
    train_loader, test_loader = loaders
    outliers_loader = utils.get_outliers_loader(args.batch_size)

    train_model(
        net, train_loader, outliers_loader, test_loader, device,
        args.epochs, args.lr,
    )
Beispiel #3
0
def main(args):
    """Load a ResNet, optionally build the EWC loss, then run training.

    Reads ``dataset``, ``label``, ``lr``, ``resnet_type``, ``ewc``,
    ``batch_size`` (and ``diag_path`` when ``ewc`` is truthy) from ``args``;
    ``args`` itself is also handed to ``train_model``.
    """
    print('Dataset: {}, Normal Label: {}, LR: {}'.format(args.dataset, args.label, args.lr))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    model = utils.get_resnet_model(resnet_type=args.resnet_type).to(device)

    if args.ewc:
        # Snapshot the model before freeze_parameters() touches it; the
        # snapshot is put in eval mode and frozen for use by EWCLoss.
        reference = deepcopy(model).to(device)
        reference.eval()
        utils.freeze_model(reference)
        ewc_loss = EWCLoss(reference, torch.load(args.diag_path))
    else:
        ewc_loss = None

    utils.freeze_parameters(model)

    loaders = utils.get_loaders(
        dataset=args.dataset,
        label_class=args.label,
        batch_size=args.batch_size,
    )
    train_loader, test_loader = loaders
    train_model(model, train_loader, test_loader, device, args, ewc_loss)
Beispiel #4
0
                                   shuffle=False,
                                   num_workers=12)
testloader = torchdata.DataLoader(
    testset, batch_size=args.batch_size, shuffle=False, num_workers=12)

# Iterate the test split when args.test is truthy, the training split otherwise.
loader = testloader if args.test else trainloader

if dataset == Data.imagenet:
    # ImageNet: resnet50 is called with pretrained=True instead of loading a
    # local checkpoint.
    rnet = resnet50(pretrained=True)
    # Only the ImageNet model is wrapped in DataParallel, and only when the
    # device is the string 'cuda'.
    if device == 'cuda':
        rnet = torch.nn.DataParallel(rnet)
else:
    # Any other dataset: build the [18, 18, 18] ResNet and restore its weights
    # from the 'rnet.t7' checkpoint under the CIFAR-10 path.
    rnet = utils.get_resnet_model(dataset, [18, 18, 18])
    rnetpath = os.path.join(
        utils.get_path(Data.cifar10, iter=args.iter), 'rnet.t7')
    rnet_dict = torch.load(rnetpath)
    rnet.load_state_dict(rnet_dict['net'])

rnet.to(device)
num_scales = 3
scale_search()
Beispiel #5
0
# Data loaders: the training loader shuffles, the test loader does not.
trainloader = D.DataLoader(trainset,
                           batch_size=args.batch_size,
                           shuffle=True,
                           num_workers=4,
                           pin_memory=True)
testloader = D.DataLoader(testset,
                          batch_size=args.batch_size,
                          shuffle=False,
                          num_workers=4,
                          pin_memory=True)

# Module-level variable recording the best test accuracy seen so far.
best_test_acc = 0.0

# Build the ResNet with block layout [3, 4, 6, 3] and restore its weights
# from the 'net' entry of the 'rnet.t7' checkpoint under rootpath.
rnet = utils.get_resnet_model(dataset_type, [3, 4, 6, 3])
rnet_dict = torch.load(os.path.join(rootpath, 'rnet.t7'))
rnet.load_state_dict(rnet_dict['net'])
# Total block count across all entries of the model's layer_config.
num_blocks = sum(rnet.layer_config)

# Policy network (ResNet-10), constructed with the total block count.
agent = Networks.PolicyNet10(num_blocks)
print(num_blocks)

# Initialise the agent's training state; start from epoch 0 unless resuming.
start_epoch = 0
if args.resume:
    # Resume from the latest agent checkpoint and pick up its stored epoch.
    print('load the check point weights...')
    ckpt = torch.load(os.path.join(rootpath, 'agent_latest.t7'))
    start_epoch = ckpt['epoch']
Beispiel #6
0
# Dataset and data loaders.
# NOTE(review): the datasets are always built for Data.mini_imagenet even
# though the model below is selected by data_type -- confirm this is intended.
trainset, testset = utils.get_policy_datasets(Data.mini_imagenet,
                                              Mode.no_policy)
trainloader = D.DataLoader(trainset,
                           batch_size=args.batch_size,
                           shuffle=True,
                           num_workers=6,
                           pin_memory=True)
testloader = D.DataLoader(testset,
                          batch_size=args.batch_size,
                          shuffle=False,
                          num_workers=6,
                          pin_memory=True)

if data_type == Data.mini_imagenet:
    rnet = utils.get_resnet_model(data_type, [3, 4, 6, 3])  # resnet 50
else:
    rnet = utils.get_resnet_model(data_type, [18, 18, 18])  # resnet 110

rnet.to(device)
if torch.cuda.device_count() > 1:
    print('paralleling for multiple GPUs...')
    rnet = nn.DataParallel(rnet)

start_epoch = 0

if args.resume:
    # Validate with an explicit raise: an `assert` is removed when Python runs
    # with -O, which would silently skip this check.
    if not os.path.isfile(os.path.join(rootpath, 'rnet.t7')):
        raise FileNotFoundError('Error: no check-point found!')
    ckpt = torch.load(os.path.join(rootpath, 'rnet.t7'))
    # NOTE(review): when rnet has been wrapped in DataParallel above, the
    # checkpoint keys must match the wrapped module's naming -- confirm how
    # 'rnet.t7' was saved.
    rnet.load_state_dict(ckpt['net'])
Beispiel #7
0
def greedy_search(args, grouping_ratio):
    """Compute a greedy per-image block-dropping order and pickle it to disk.

    For every batch the residual blocks are dropped one per step. At step
    ``limit`` each still-active block is tentatively dropped in turn, the
    network is evaluated with the resulting keep-mask, and one candidate per
    image is committed. The resulting array stores, for each image and block,
    the step (1..num_block) at which that block is dropped.

    The array is written to ``greedy_test.pkl`` when ``args.test`` is truthy,
    otherwise to ``greedy_train.pkl``, inside
    ``utils.get_path(Data.mini_imagenet, iter=args.iter)``.

    Args:
        args: Options object; this function reads ``data``, ``batch_size``,
            ``test``, ``iter`` and ``gpu``.
        grouping_ratio: Multiplicative tolerance applied to the running best
            target score. A candidate whose score is at least
            ``max_prob * grouping_ratio`` becomes the selected block.

    Returns:
        None. The policy array is pickled as a side effect.
    """
    num_cuda = torch.cuda.device_count()
    data_type = utils.str2Data(args.data)

    # Datasets are built without training-time augmentation and iterated in
    # a fixed order (shuffle=False) so rows line up with dataset indices.
    trainset, testset = utils.get_policy_datasets(Data.mini_imagenet, Mode.no_policy, augment_on_training=False)
    trainloader = torchdata.DataLoader(trainset, batch_size=num_cuda * args.batch_size, shuffle=False,
                                       num_workers=4, pin_memory=True)
    testloader = torchdata.DataLoader(testset, batch_size=num_cuda * args.batch_size, shuffle=False,
                                      num_workers=4, pin_memory=True)

    loader = testloader if args.test else trainloader

    rnet = utils.get_resnet_model(data_type, [3, 4, 6, 3])  # resnet 50
    rnetpath = os.path.join(utils.get_path(Data.mini_imagenet, iter=args.iter), 'rnet.t7')
    rnet_dict = torch.load(rnetpath)
    rnet.load_state_dict(rnet_dict['net'])
    rnet.to(device)  # `device` is a module-level name

    # Total number of residual blocks in the ResNet. Read it before the
    # optional DataParallel wrap: the wrapper does not forward custom
    # attributes of the wrapped module.
    num_block = sum(rnet.layer_config)

    if num_cuda > 1:
        print('paralleling for multiple GPUs...')
        rnet = torch.nn.DataParallel(rnet)

    rnet.eval().cuda(args.gpu)

    # greedy_policy[image_id][block] == step at which `block` is dropped.
    # Sized with the loader's real batch size (num_cuda * args.batch_size) so
    # every batch fits when more than one GPU is present.
    greedy_policy = np.zeros((len(loader) * loader.batch_size, num_block))
    print('Applying Greedy Search on %s dataset ...' % args.data)
    index_counter = 0
    print('Number of Layers: %d' % num_block)
    for batch_idx, (inputs, targets) in tqdm.tqdm(enumerate(loader), total=len(loader)):
        batch_size = inputs.shape[0]
        batch_policy = np.zeros((batch_size, num_block))

        # Iteratively drop residual blocks; `limit` is the 1-based step
        # number, i.e. how many blocks have been dropped once it completes.
        for limit in range(1, num_block + 1):
            # Start the running maximum at a negative sentinel value.
            max_prob = -100 * np.ones(batch_size)
            # Builtin int replaces the removed `np.int` alias.
            max_prob_drop_idx = np.zeros(batch_size, dtype=int)
            drop_idx = np.zeros(batch_size, dtype=int)

            for _ in range(num_block - limit + 1):
                # Tentatively drop the next still-active block of each image.
                for j in range(batch_size):
                    while batch_policy[j][drop_idx[j]] != 0:
                        drop_idx[j] += 1
                    batch_policy[j][drop_idx[j]] = limit
                # Keep-mask: 1 where a block is still active, 0 where dropped.
                p = torch.from_numpy((batch_policy == 0).astype(int)).view(batch_size, -1)
                p_var, inputs_var = Variable(p).cuda(args.gpu), Variable(inputs).cuda(args.gpu)
                preds = rnet.forward(inputs_var, p_var)
                for j in range(batch_size):
                    prob = preds.data[j][targets[j]]
                    if prob > max_prob[j]:
                        max_prob[j] = prob
                    if prob >= max_prob[j] * grouping_ratio:
                        max_prob_drop_idx[j] = drop_idx[j]
                    # Undo the tentative drop and move on to the next block.
                    batch_policy[j][drop_idx[j]] = 0
                    drop_idx[j] += 1
            # Commit the selected block for this step.
            for j in range(batch_size):
                batch_policy[j][max_prob_drop_idx[j]] = limit
        greedy_policy[index_counter: index_counter + batch_size] = batch_policy
        index_counter += batch_size

    # Write next to the checkpoint that was loaded above (same get_path call).
    writedir = utils.get_path(Data.mini_imagenet, iter=args.iter)
    os.makedirs(writedir, exist_ok=True)
    out_name = 'greedy_test.pkl' if args.test else 'greedy_train.pkl'
    # The context manager guarantees the file is flushed and closed.
    with open(os.path.join(writedir, out_name), 'wb') as f:
        pickle.dump(greedy_policy, f)
    print('done.')
Beispiel #8
0
            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()
    avg_loss = total_loss / num_batch
    accuracy = 100.0 * correct_samples / total_samples
    if mode == "val" or (eval_type == "test" and mode == "test"):
        loss_history.append(avg_loss)
        acc_history.append(accuracy)
        sys.stdout.write(' %s: %.4f -- %s: %.2f \n' %
                         (mode + "_loss", avg_loss, mode + "_acc", accuracy))
    return avg_loss, accuracy


if __name__ == '__main__':
    # Model types whose name begins with "resnet" go to the ResNet factory;
    # everything else is treated as a ViT variant.
    if opt.model_type.startswith("resnet"):
        model = get_resnet_model(
            resnet_type=opt.model_type, n_classes=opt.n_classes)
        model_name = opt.model_type
    else:
        # The ViT name is derived from type, patch size and the hybrid flag.
        model_name = get_ViT_name(
            model_type=opt.model_type,
            patch_size=opt.patch_size,
            hybrid=opt.hybrid,
        )
        model = get_ViT_model(
            type=opt.model_type,
            image_size=opt.image_size,
            patch_size=opt.patch_size,
            n_classes=opt.n_classes,
            n_channels=opt.n_channels,
            dropout=opt.dropout,
            hybrid=opt.hybrid,
        )

    # Output locations are resolved from the model name, root path and dataset.
    output_graph_path, dump_file = get_output_path(
        model_name=model_name,
        root_path=opt.output_root_path,
        dataset_name=opt.dataset_name,
    )
Beispiel #9
0
                                          index=out_df.columns),
                                ignore_index=True)
 # Hybrid ViT variants: build each model with fixed settings (224x224 input,
 # patch size 16, 3 channels, 10 classes, no dropout), measure it, and add a
 # row labelled "resnet18+<vit>" to out_df.
 # NOTE(review): DataFrame.append is not available in pandas >= 2.0; if out_df
 # is a pandas DataFrame these calls need pd.concat on newer versions.
 for vit in vit_models:
     hybrid_model = get_ViT_model(type=vit,
                                  image_size=224,
                                  patch_size=16,
                                  n_channels=3,
                                  n_classes=10,
                                  dropout=0.,
                                  hybrid=True)
     # Complexity is requested as strings for a (3, 224, 224) input.
     macs, params = get_model_complexity_info(hybrid_model, (3, 224, 224),
                                              as_strings=True,
                                              print_per_layer_stat=True,
                                              verbose=True)
     # Row contents: model label, MACs, parameter count left-aligned to width 8.
     data_to_add = ["resnet18+" + vit, str(macs), '{:<8}'.format(params)]
     data_df_scores = np.hstack((np.array(data_to_add).reshape(1, -1)))
     out_df = out_df.append(pd.Series(data_df_scores.reshape(-1),
                                      index=out_df.columns),
                            ignore_index=True)
 # Plain ResNet variants: same measurement, with the row labelled by the
 # ResNet type itself.
 for resnet in resnet_models:
     model = get_resnet_model(resnet_type=resnet, n_classes=10)
     macs, params = get_model_complexity_info(model, (3, 224, 224),
                                              as_strings=True,
                                              print_per_layer_stat=True,
                                              verbose=True)
     data_to_add = [resnet, str(macs), '{:<8}'.format(params)]
     data_df_scores = np.hstack((np.array(data_to_add).reshape(1, -1)))
     out_df = out_df.append(pd.Series(data_df_scores.reshape(-1),
                                      index=out_df.columns),
                            ignore_index=True)
 # Persist the collected rows, with a header row and without the index column.
 out_df.to_csv(output_cost_csv, index=False, header=True)