Example #1
def __init__(self):
    self.skeleton = []
    self.KPV = KeyPointVisualizer()
    pose_dataset = Mscoco()
    if config.fast_inference:
        self.pose_model = InferenNet_fast(4 * 1 + 1, pose_dataset)
    else:
        self.pose_model = InferenNet(4 * 1 + 1, pose_dataset)
    # self.pose_model = createModel(cfg=model_cfg)
    if config.device != "cpu":
        self.pose_model.cuda()
    self.pose_model.eval()  # inference mode on CPU as well, not only on CUDA
    self.batch_size = config.pose_batch
    flops = print_model_param_flops(self.pose_model)
    print("FLOPs of the current pose estimation model: {}".format(flops))
Example #2
def write_info(m, metric, string):
    # f, folder_str and random_input are defined in the enclosing scope
    params = print_model_param_nums(m)
    flops = print_model_param_flops(m)
    if string == "origin":
        f.write(('\n' + '%50s' * 1) % "origin")
    else:
        f.write(('\n' + '%50s' * 1) % ("{}-{}".format(string, folder_str)))
    f.write(('%15s' * 1) % ("{}".format(flops)))
    processed_metric = [round(v, 4) for v in metric[0]]  # v, not m, to avoid shadowing the model argument
    inf_time, _ = obtain_avg_forward_time(random_input, m)
    inf_time = round(inf_time, 4)
    f.write(('%10s' * 9) % (
        "{}".format(inf_time), "{}".format(params), "{}".format(processed_metric[0]),
        "{}".format(processed_metric[1]),
        "{}".format(processed_metric[2]), "{}".format(processed_metric[3]), "{}".format(processed_metric[4]),
        "{}".format(processed_metric[5]), "{}\n".format(processed_metric[6]),))
Example #3
    eval_model = lambda model: test(model=model,
                                    cfg=opt.cfg,
                                    data=opt.data,
                                    batch_size=16,
                                    img_size=img_size,
                                    conf_thres=0.1)
    obtain_num_parameters = lambda model: sum(
        [param.nelement() for param in model.parameters()])

    print("\nTesting model {}:".format(model_name))
    with torch.no_grad():
        metric, _ = eval_model(model)
    print(metric)
    metric = [round(m, 4) for m in metric]

    flops = print_model_param_flops(model)
    params = print_model_param_nums(model)

    random_input = torch.rand((1, 3, img_size, img_size)).to(device)
    forward_time, _ = obtain_avg_forward_time(random_input, model)
    forward_time = round(forward_time, 4)

    file.write(('\n' + '%70s' * 1) % ("{}".format(model_name)))
    file.write(('%15s' * 1) % ("{}".format(flops)))
    file.write(('%10s' * 9) % (
        "{}".format(forward_time),
        "{}".format(params),
        "{}".format(metric[0]),
        "{}".format(metric[1]),
        "{}".format(metric[2]),
        "{}".format(metric[3]),
        "{}".format(metric[4]),
        "{}".format(metric[5]),
        "{}\n".format(metric[6])))
Example #4
def main(config):
    stats = {}
    device = 'cuda'
    criterion = torch.nn.CrossEntropyLoss()

    # config = init_config() if config is None else config
    logger, writer = init_summary_writer(config)
    trainloader, testloader = init_dataloader(config)
    net, bottleneck_net = init_network(config, logger, device)
    pruner = init_pruner(net, bottleneck_net, config, writer, logger)

    # start pruning
    epochs = str_to_list(config.epoch, ',', int)
    learning_rates = str_to_list(config.learning_rate, ',', float)
    weight_decays = str_to_list(config.weight_decay, ',', float)
    ratios = str_to_list(config.ratio, ',', float)

    fisher_type = config.fisher_type  # empirical|true
    fisher_mode = config.fisher_mode  # eigen|full|diagonal
    normalize = config.normalize
    prune_mode = config.prune_mode  # one_pass | iterative
    fix_rotation = config.get('fix_rotation', True)

    assert (len(epochs) == len(learning_rates)
            and len(learning_rates) == len(weight_decays)
            and len(weight_decays) == len(ratios))

    total_parameters = count_parameters(net.train())
    for it in range(len(epochs)):
        epoch = epochs[it]
        lr = learning_rates[it]
        wd = weight_decays[it]
        ratio = ratios[it]
        logger.info('-' * 120)
        logger.info('** [%d], Ratio: %.2f, epoch: %d, lr: %.4f, wd: %.4f' %
                    (it, ratio, epoch, lr, wd))
        logger.info(
            'Reinit: %s, Fisher_mode: %s, fisher_type: %s, normalize: %s, fix_rotation: %s.'
            % (config.re_init, fisher_mode, fisher_type, normalize,
               fix_rotation))
        pruner.fix_rotation = fix_rotation

        # conduct pruning
        cfg = pruner.make_pruned_model(trainloader,
                                       criterion=criterion,
                                       device=device,
                                       fisher_type=fisher_type,
                                       prune_ratio=ratio,
                                       normalize=normalize,
                                       re_init=config.re_init)

        # for tracking the best accuracy
        compression_ratio, unfair_ratio, all_numel, rotation_numel = compute_ratio(
            pruner.model, total_parameters, fix_rotation, logger)
        # FLOPs are measured at the dataset's input resolution
        input_res = 64 if config.dataset == 'tiny_imagenet' else 32
        total_flops, rotation_flops = print_model_param_flops(pruner.model,
                                                              input_res,
                                                              cuda=True)
        train_loss_pruned, train_acc_pruned = pruner.test_model(
            trainloader, criterion, device)
        test_loss_pruned, test_acc_pruned = pruner.test_model(
            testloader, criterion, device)

        # write results
        logger.info('Before: Accuracy: %.2f%%(train), %.2f%%(test).' %
                    (train_acc_pruned, test_acc_pruned))
        logger.info('        Loss:     %.2f  (train), %.2f  (test).' %
                    (train_loss_pruned, test_loss_pruned))

        test_loss_finetuned, test_acc_finetuned = pruner.fine_tune_model(
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optim=optim,
            learning_rate=lr,
            weight_decay=wd,
            nepochs=epoch)
        train_loss_finetuned, train_acc_finetuned = pruner.test_model(
            trainloader, criterion, device)
        logger.info('After:  Accuracy: %.2f%%(train), %.2f%%(test).' %
                    (train_acc_finetuned, test_acc_finetuned))
        logger.info('        Loss:     %.2f  (train), %.2f  (test).' %
                    (train_loss_finetuned, test_loss_finetuned))
        # save model

        stat = {
            'total_flops': total_flops,
            'rotation_flops': rotation_flops,
            'it': it,
            'prune_ratio': ratio,
            'cr': compression_ratio,
            'unfair_cr': unfair_ratio,
            'all_params': all_numel,
            'rotation_params': rotation_numel,
            'prune/train_loss': train_loss_pruned,
            'prune/train_acc': train_acc_pruned,
            'prune/test_loss': test_loss_pruned,
            'prune/test_acc': test_acc_pruned,
            'finetune/train_loss': train_loss_finetuned,
            'finetune/test_loss': test_loss_finetuned,
            'finetune/train_acc': train_acc_finetuned,
            'finetune/test_acc': test_acc_finetuned
        }
        save_model(config, it, pruner, cfg, stat)

        stats[it] = stat

        if prune_mode == 'one_pass':
            del net
            del pruner
            net, bottleneck_net = init_network(config, logger, device)
            pruner = init_pruner(net, bottleneck_net, config, writer, logger)
            pruner.iter = it
        with open(os.path.join(config.summary_dir, 'stats.json'), 'w') as f:
            json.dump(stats, f)
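Example #4 parses its comma-separated schedules with a str_to_list helper. Assuming it simply splits on the delimiter and casts every token, a sketch:

def str_to_list(s, delimiter, converter):
    """e.g. str_to_list('60,40,20', ',', int) -> [60, 40, 20]"""
    return [converter(tok) for tok in s.split(delimiter)]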
Example #5
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    print(args.dataset, args.network, args.depth)
    print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' %
          (start_epoch, best_acc))
    raise Exception('Test for Acc.')  # stops execution after reporting the checkpoint accuracy

# init summary writer
log_dir = os.path.join(args.log_dir,
                       '%s_%s%s' % (args.dataset, args.network, args.depth))
makedirs(log_dir)
writer = SummaryWriter(log_dir)

if args.dataset == 'tiny_imagenet':
    total_flops, rotation_flops = print_model_param_flops(net, 64, cuda=True)
elif args.dataset == 'imagenet':
    total_flops, rotation_flops = print_model_param_flops(net, 224, cuda=True)
else:
    total_flops, rotation_flops = print_model_param_flops(net, 32, cuda=True)
num_params = count_parameters(net)
print(f"Total Flops: {total_flops}")
print(f"Total Params: {num_params}")


def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
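count_parameters, shared by Examples #4 and #5, likely totals the model's trainable parameters; whether it filters on requires_grad is an assumption:

def count_parameters(model):
    # total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)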
Example #6
def main():
    # Prepare Dataset

    train_dataset = MyDataset(config.train_info, train=True)
    val_dataset = MyDataset(config.train_info, train=False)
    # for k, v in config.train_info.items():
    #     pass
    # train_dataset = Mscoco(v, train=True)
    # val_dataset = Mscoco(v, train=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch,
        shuffle=True,
        num_workers=config.train_mum_worker,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.val_batch,
                                             shuffle=True,
                                             num_workers=config.val_num_worker,
                                             pin_memory=True)

    # for k, v in config.train_info.items():
    #     train_dataset = Mscoco([v[0], v[1]], train=True, val_img_num=v[2])
    #     val_dataset = Mscoco([v[0], v[1]], train=False, val_img_num=v[2])
    #
    # train_loaders[k] = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=config.train_batch, shuffle=True, num_workers=config.train_mum_worker,
    #     pin_memory=True)
    #
    # val_loaders[k] = torch.utils.data.DataLoader(
    #     val_dataset, batch_size=config.val_batch, shuffle=False, num_workers=config.val_num_worker, pin_memory=True)
    #
    # train_loader = torch.utils.data.DataLoader(
    #         train_dataset, batch_size=config.train_batch, shuffle=True, num_workers=config.train_mum_worker,
    #         pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(
    #         val_dataset, batch_size=config.val_batch, shuffle=False, num_workers=config.val_num_worker, pin_memory=True)

    # assert train_loaders != {}, "Your training data has not been specified!"

    # Model Initialize
    if device != "cpu":
        m = createModel(cfg=model_cfg).cuda()
    else:
        m = createModel(cfg=model_cfg).cpu()

    begin_epoch = 0
    pre_train_model = config.loadModel
    flops = print_model_param_flops(m)
    print("FLOPs of the current model: {}".format(flops))
    params = print_model_param_nums(m)
    print("Parameters of the current model: {}".format(params))

    if pre_train_model:
        print('Loading model from {}'.format(pre_train_model))
        m.load_state_dict(torch.load(pre_train_model))
        # parse the epoch index out of the checkpoint filename first,
        # then derive the resumed iteration counters from it
        begin_epoch = int(pre_train_model.split("_")[-1][:-4]) + 1
        opt.trainIters = config.train_batch * (begin_epoch - 1)
        opt.valIters = config.val_batch * (begin_epoch - 1)
        os.makedirs("exp/{}/{}".format(dataset, save_folder), exist_ok=True)
    else:
        print('Creating a new model')
        with open("log/{}.txt".format(save_folder), "a+") as f:
            f.write("FLOPs of the current model: {}\n".format(flops))
            f.write("Parameters of the current model: {}\n".format(params))
        os.makedirs("exp/{}/{}".format(dataset, save_folder), exist_ok=True)

    if optimize == 'rmsprop':
        optimizer = torch.optim.RMSprop(m.parameters(),
                                        lr=config.lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weightDecay)
    elif optimize == 'adam':
        optimizer = torch.optim.Adam(m.parameters(),
                                     lr=config.lr,
                                     weight_decay=config.weightDecay)
    else:
        raise ValueError("Unsupported optimizer: {}".format(optimize))

    if mix_precision:
        m, optimizer = amp.initialize(m, optimizer, opt_level="O1")

    writer = SummaryWriter('tensorboard/{}/{}'.format(dataset, save_folder))

    # Model Transfer
    if device != "cpu":
        m = torch.nn.DataParallel(m).cuda()
        criterion = torch.nn.MSELoss().cuda()
    else:
        m = torch.nn.DataParallel(m)
        criterion = torch.nn.MSELoss()

    # torch.random is a module, not a callable; use torch.rand and match the model's device
    rnd_inps = torch.rand(2, 3, 224, 224)
    if device != "cpu":
        rnd_inps = rnd_inps.cuda()
    writer.add_graph(m, rnd_inps)

    # Start Training
    for i in range(begin_epoch, config.epochs):
        os.makedirs("log/{}".format(dataset), exist_ok=True)
        log = open("log/{}/{}.txt".format(dataset, save_folder), "a+")
        print('############# Starting Epoch {} #############'.format(i))
        log.write('############# Starting Epoch {} #############\n'.format(i))

        for name, param in m.named_parameters():
            writer.add_histogram(name, param.clone().data.to("cpu").numpy(), i)

        loss, acc = train(train_loader, m, criterion, optimizer, writer)

        print('Train-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}'.format(
            idx=i, loss=loss, acc=acc))
        log.write(
            'Train-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}\n'.format(
                idx=i, loss=loss, acc=acc))

        opt.acc = acc
        opt.loss = loss
        m_dev = m.module

        loss, acc = valid(val_loader, m, criterion, optimizer, writer)

        print('Valid-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}'.format(
            idx=i, loss=loss, acc=acc))
        log.write(
            'Valid-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}\n'.format(
                idx=i, loss=loss, acc=acc))
        log.close()

        if i % config.save_interval == 0:
            torch.save(
                m_dev.state_dict(),
                'exp/{}/{}/model_{}.pkl'.format(dataset, save_folder, i))
            torch.save(opt,
                       'exp/{}/{}/option.pkl'.format(dataset, save_folder))
            torch.save(optimizer,
                       'exp/{}/{}/optimizer.pkl'.format(dataset, save_folder))

    writer.close()
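The checkpoints written by the loop above can be restored through the same config.loadModel path at the top of Example #6. A small illustration of the filename-based resume logic, with the path made up for the example:

checkpoint = 'exp/coco/run1/model_10.pkl'  # hypothetical path for illustration
next_epoch = int(checkpoint.split("_")[-1][:-4]) + 1
print(next_epoch)  # 11 -- training would resume from epoch 11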