Code example #1
0
def main():
    """Benchmark ``prune_on_cpu`` on a single batch of training images.

    Parses the pruning command-line arguments into a config, builds the
    Darknet model together with its pruning mask and weight backup, loads
    one batch from the training set, then calls ``prune_on_cpu`` a fixed
    number of times. It prints the start time, the end time and the total
    elapsed duration.
    """
    import time  # local import: only needed for the monotonic benchmark clock

    args = create_prune_argparser()
    config = create_config(args)

    # Initialize
    init_seeds(seed=0)

    model = Darknet(cfg=config['cfg'], arc=config['arc'])
    mask = create_mask(model)
    bckp = create_backup(model)
    device = select_device(config['device'])

    model = model.to(device)

    data_dict = parse_data_cfg(config['data'])
    train_path = data_dict['train']

    dataset = LoadImagesAndLabels(
        path=train_path,
        img_size=config['img_size'][0],
        batch_size=config['batch_size'],
        augment=True,
        hyp=config['hyp'],
        cache_images=config['cache_images'],
    )

    # The benchmark batch size is fixed at 18, independently of
    # config['batch_size'] (which is only handed to the dataset above).
    # NOTE(review): confirm the two are meant to differ.
    batch_size = 18
    n_iterations = 10

    # Dataloader. os.cpu_count() returns None when the CPU count cannot be
    # determined; fall back to 1 so min() never compares None with an int.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Only the first batch is used. Pixel values are scaled by 1/255
    # (presumably uint8 in [0, 255] -- TODO confirm against the dataset).
    imgs, _, _, _ = next(iter(dataloader))
    imgs = imgs.float() / 255.0
    imgs = imgs.to(device)

    def _sync():
        # CUDA kernels run asynchronously: wait for queued GPU work so it is
        # counted inside the timed region and does not leak past it.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    _sync()
    start = datetime.datetime.now()
    print(f'Starting to compute the time at {start}')
    t0 = time.perf_counter()
    for _ in range(n_iterations):
        prune_on_cpu(model, mask, bckp, imgs, config, device)
    _sync()
    elapsed = time.perf_counter() - t0
    end = datetime.datetime.now()
    print(f'Ending at {end}')
    # Measure the duration with a monotonic clock, which is immune to system
    # clock adjustments, but still report it as a timedelta.
    result = datetime.timedelta(seconds=elapsed)
    print(f'Time of {result}')
Code example #2
0
File: my_kd_gan.py  Project: van-hub/yolo_compression
            # os.system('gsutil cp %s gs://%s/weights' % (config['sub_working_dir'] + fbest, config['bucket']))

    if not config['evolve']:
        plot_results(folder=config['sub_working_dir'])

    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1,
                                                    (time.time() - t0) / 3600))
    dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()

    return results


if __name__ == '__main__':
    # Script entry point: parse the knowledge-distillation arguments into a
    # config dict and derive this run's output file paths from it.
    args = create_kd_argparser()
    config = create_config(args)
    print("sub working dir: %s" % config['sub_working_dir'])

    # Saving configurations. This is done before the output paths below are
    # added, so config.json holds only the parsed settings. The ``with``
    # statement closes the file, so no explicit close() is needed.
    import json
    with open(config['sub_working_dir'] + 'config.json', 'w') as f:
        json.dump(config, f)

    # Per-run output files. sub_working_dir is used as a plain string prefix,
    # so it presumably ends with a path separator -- TODO confirm.
    config['last'] = config['sub_working_dir'] + 'last.pt'
    config['best_gan'] = config['sub_working_dir'] + 'best_gan.pt'
    config['best'] = config['sub_working_dir'] + 'best.pt'
    config['results_file'] = config['sub_working_dir'] + 'results.txt'

    print(config)