Exemple #1
0
def test_create_model_imagenet():
    """Check that ImageNet models can be created and unknown archs are rejected."""
    # (pretrained, arch) combinations that are expected to build without error.
    buildable = (
        (False, 'alexnet'),
        (False, 'resnet50'),
        (True, 'resnet50'),
    )
    for pretrained, arch in buildable:
        create_model(pretrained, 'imagenet', arch)

    # An unrecognized architecture name must be rejected with ValueError.
    with pytest.raises(ValueError):
        create_model(False, 'imagenet', 'no_such_model!')
Exemple #2
0
def test_load_gpu_model_on_cpu_with_thinning():
    """Save a thinned model to a checkpoint and load it into a CPU model.

    Regression test for issue #148:
      1. create a model and remove 50% of the filters in one of its layers
         (thinning)
      2. save the thinned model in a checkpoint file
      3. load the checkpoint into a model created with the CPU device id
    """
    cpu_device_id = -1
    arch, dataset = 'resnet20_cifar', 'cifar10'
    param_name = "module.layer1.0.conv1.weight"

    thinned_model = create_model(False, dataset, arch)
    weights = distiller.model_find_param(thinned_model, param_name)
    filter_pruner = distiller.pruning.L1RankedStructureParameterPruner(
        "test_pruner", group_type="Filters", desired_sparsity=0.5, weights=param_name)
    masks = distiller.create_model_masks_dict(thinned_model)
    filter_pruner.set_param_mask(weights, param_name, masks, meta=None)

    # Apply the mask to the weights, then physically remove the filters.
    masks[param_name].apply_mask(weights)
    distiller.remove_filters(thinned_model, masks, arch, dataset, optimizer=None)
    assert hasattr(thinned_model, 'thinning_recipes')

    save_checkpoint(epoch=0, arch=arch, model=thinned_model,
                    scheduler=distiller.CompressionScheduler(thinned_model),
                    optimizer=None)

    cpu_model = create_model(False, dataset, arch, device_ids=cpu_device_id)
    load_checkpoint(cpu_model, "checkpoint.pth.tar")
    assert distiller.model_device(cpu_model) == 'cpu'
Exemple #3
0
def test_create_model_cifar():
    """Check which CIFAR configurations create_model accepts and rejects."""
    # A non-pretrained CIFAR10 ResNet20 is expected to build.
    create_model(False, 'cifar10', 'resnet20_cifar')

    # Each (pretrained, dataset, arch) below is expected to raise ValueError.
    rejected = (
        (False, 'cifar100', 'resnet20_cifar'),  # only cifar _10_ is currently supported
        (False, 'cifar10', 'no_such_model!'),   # unknown architecture name
        (True, 'cifar10', 'resnet20_cifar'),    # no pretrained models of cifar10
    )
    for pretrained, dataset, arch in rejected:
        with pytest.raises(ValueError):
            create_model(pretrained, dataset, arch)
def name_test(dataset, arch):
    """Compare module names of a model built with and without ``parallel=True``.

    The parallel model is expected to report exactly one more named module
    than the non-parallel one. Skipping the root entry of the non-parallel
    model and the first two entries of the parallel model, every name pair
    must round-trip through ``normalize_module_name`` and
    ``denormalize_module_name``.
    """
    serial_model = create_model(False, dataset, arch, parallel=False)
    parallel_model = create_model(False, dataset, arch, parallel=True)
    assert serial_model is not None and parallel_model is not None

    serial_names = [name for name, _ in serial_model.named_modules()]
    parallel_names = [name for name, _ in parallel_model.named_modules()]
    assert serial_names is not None and parallel_names is not None
    assert len(serial_names)+1 == len(parallel_names)

    # Pair serial_names[k+1] with parallel_names[k+2] for every valid k.
    for serial_name, parallel_name in zip(serial_names[1:], parallel_names[2:]):
        normalized = normalize_module_name(parallel_name)
        assert serial_name == normalized
        logging.debug("{} {} {}".format(parallel_name, serial_name, normalized))
        assert parallel_name == denormalize_module_name(parallel_model, serial_name)
Exemple #5
0
def named_params_layers_test_aux(dataset, arch, dataparallel: bool):
    """Verify every layer reported by named_params_layers() has an op in the graph.

    Args:
        dataset: dataset name, passed to ``create_model`` and ``get_input``.
        arch: architecture name, passed to ``create_model``.
        dataparallel: forwarded to ``create_model`` as ``parallel``.

    Raises:
        AssertionError: if ``SummaryGraph.find_op`` returns None for a layer name.
    """
    model = create_model(False, dataset, arch, parallel=dataparallel)
    sgraph = SummaryGraph(model, get_input(dataset))
    sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
    for layer_name in sgraph_layer_names:
        # The condition and the message must not be wrapped in parentheses
        # together: `assert (cond, msg)` asserts a non-empty tuple, which is
        # always true and silently disables the check.
        assert sgraph.find_op(layer_name) is not None, \
            '{} was not found in summary graph'.format(layer_name)
Exemple #6
0
def test_load_checkpoint_without_model():
    """Load checkpoints without passing a model, with and without metadata.

    The stored checkpoint is expected to be rejected when no model is given.
    A checkpoint freshly written by ``save_checkpoint`` is then loaded with
    ``model=None`` for each of the device settings ``None``, ``'cuda'`` and
    ``'cpu'``.
    """
    checkpoint_filename = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
    # Load a checkpoint w/o specifying the model: this should fail because the loaded
    # checkpoint is old and does not have the required metadata to create a model.
    with pytest.raises(ValueError):
        load_checkpoint(model=None, chkpt_file=checkpoint_filename)

    for model_device in (None, 'cuda', 'cpu'):
        # Now we create a new model, save a checkpoint, and load it w/o specifying the model.
        # This should succeed because the checkpoint has enough metadata to create model.
        model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
        model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
            model, checkpoint_filename)
        save_checkpoint(epoch=0,
                        arch='resnet20_cifar',
                        model=model,
                        name='eraseme',
                        scheduler=compression_scheduler,
                        optimizer=None,
                        dir='checkpoints')
        temp_checkpoint = os.path.join("checkpoints",
                                       "eraseme_checkpoint.pth.tar")
        try:
            model, compression_scheduler, optimizer, start_epoch = load_checkpoint(
                model=None, chkpt_file=temp_checkpoint, model_device=model_device)
            assert compression_scheduler is not None
            assert optimizer is None
            assert start_epoch == 1
            assert model
            assert model.arch == "resnet20_cifar"
            assert model.dataset == "cifar10"
        finally:
            # Remove the temporary checkpoint even when a check above fails,
            # so a failing run does not leave the file behind.
            os.remove(temp_checkpoint)
Exemple #7
0
def test_policy_scheduling():
    """Exercise CompressionScheduler.add_policy with invalid and valid schedules."""
    model = create_model(False, 'cifar10', 'resnet20_cifar')
    scheduler = distiller.CompressionScheduler(model)
    policy = distiller.PruningPolicy(None, None)

    def check_metadata(first_epoch, last_epoch):
        # Metadata recorded for `policy` after an `epochs=[...]` registration.
        metadata = scheduler.sched_metadata[policy]
        assert metadata['starting_epoch'] == first_epoch
        assert metadata['ending_epoch'] == last_epoch
        assert metadata['frequency'] is None

    # Adding a policy with no scheduling arguments is expected to assert.
    with pytest.raises(AssertionError):
        scheduler.add_policy(policy)

    # Test for mutual-exclusive configuration
    conflicting = dict(epochs=[1, 2, 3], starting_epoch=4, ending_epoch=5, frequency=1)
    with pytest.raises(AssertionError):
        scheduler.add_policy(policy, **conflicting)

    scheduler.add_policy(policy, epochs=None, starting_epoch=4, ending_epoch=5, frequency=1)

    # Regression test for issue #176 - https://github.com/NervanaSystems/distiller/issues/176
    scheduler.add_policy(policy, epochs=[1, 2, 3])
    check_metadata(1, 4)

    scheduler.add_policy(policy, epochs=[5])
    check_metadata(5, 6)
def create_graph(dataset, arch):
    """Build a SummaryGraph for a non-parallel ``arch`` model on ``dataset``.

    Raises:
        AssertionError: if ``get_input`` returns None for ``dataset`` or
            ``create_model`` returns None.
    """
    graph_input = get_input(dataset)
    assert graph_input is not None, (
        "Unsupported dataset ({}) - aborting draw operation".format(dataset))

    net = create_model(False, dataset, arch, parallel=False)
    assert net is not None
    return SummaryGraph(net, graph_input)
Exemple #9
0
def test_load():
    """Load the dense SSL example checkpoint and check scheduler and start epoch."""
    logging.getLogger('simple_example').setLevel(logging.INFO)

    checkpoint_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    net = create_model(False, 'cifar10', 'resnet20_cifar')
    net, scheduler, resumed_epoch = load_checkpoint(net, checkpoint_path)
    assert scheduler is not None
    assert resumed_epoch == 180
def test_named_params_layers(dataset, arch, parallel):
    """Verify every layer reported by named_params_layers() has an op in the graph."""
    net = create_model(False, dataset, arch, parallel=parallel)
    graph = SummaryGraph(net, distiller.get_dummy_input(dataset))
    # Only the first element of each reported triple (the layer name) is used.
    layer_names = {name for name, _, _ in graph.named_params_layers()}
    for layer_name in layer_names:
        found_op = graph.find_op(layer_name)
        assert found_op is not None, '{} was not found in summary graph'.format(layer_name)
def _init_learner(args):
    """Create the model, optimizer and compression scheduler described by ``args``.

    Optionally resumes from a full checkpoint or loads weights from a lean
    checkpoint, honors ``args.reset_optimizer``, and falls back to a fresh
    SGD optimizer and an empty CompressionScheduler when none was loaded.

    Returns:
        The tuple ``(model, compression_scheduler, optimizer, start_epoch,
        args.epochs)``.
    """
    net = create_model(args.pretrained, args.dataset, args.arch,
                       parallel=not args.load_serialized, device_ids=args.gpus)
    scheduler = None

    # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1
    if args.deprecated_resume:
        msglogger.warning('The "--resume" flag is deprecated. Please use "--resume-from=YOUR_PATH" instead.')
        if not args.reset_optimizer:
            msglogger.warning('If you wish to also reset the optimizer, call with: --reset-optimizer')
            args.reset_optimizer = True
        args.resumed_checkpoint_path = args.deprecated_resume

    sgd, first_epoch = None, 0
    if args.resumed_checkpoint_path:
        # Full resume: model, scheduler, optimizer and epoch come from the checkpoint.
        net, scheduler, sgd, first_epoch = apputils.load_checkpoint(
            net, args.resumed_checkpoint_path, model_device=args.device)
    elif args.load_model_path:
        # Lean load: only the model is taken from the checkpoint.
        net = apputils.load_lean_checkpoint(net, args.load_model_path, model_device=args.device)

    if args.reset_optimizer:
        first_epoch = 0
        if sgd is not None:
            sgd = None
            msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')

    if sgd is None:
        sgd = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay)
        msglogger.debug('Optimizer Type: %s', type(sgd))
        msglogger.debug('Optimizer Args: %s', sgd.defaults)

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        resumed_from_epoch = (first_epoch - 1) if args.resumed_checkpoint_path else None
        scheduler = distiller.file_config(net, sgd, args.compress, scheduler, resumed_from_epoch)
        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
        net.to(args.device)
    elif scheduler is None:
        scheduler = distiller.CompressionScheduler(net)

    return net, scheduler, sgd, first_epoch, args.epochs
Exemple #12
0
def test_load_gpu_model_on_cpu():
    """Load the dense SSL example checkpoint into a model created with device id -1.

    Regression test for issue #148.
    """
    checkpoint_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    cpu_device_id = -1
    net = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=cpu_device_id)
    net, scheduler, resumed_epoch = load_checkpoint(net, checkpoint_path)
    assert scheduler is not None
    assert resumed_epoch == 180
    assert distiller.model_device(net) == 'cpu'
Exemple #13
0
def test_load_dumb_checkpoint():
    """Loading a file that holds only the bare state-dict must raise ValueError.

    The ``state_dict`` entry of the dense SSL example checkpoint is saved on
    its own (not wrapped in a dict under a ``'state_dict'`` key) and passed to
    ``load_checkpoint``.
    """
    # prepare lean checkpoint
    state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')

    with tempfile.NamedTemporaryFile() as tmpfile:
        torch.save(state_dict_arrays, tmpfile.name)
        model = create_model(False, 'cifar10', 'resnet20_cifar')
        with pytest.raises(ValueError):
            # Deliberately do not unpack the return value here: a tuple-unpacking
            # length mismatch also raises ValueError, which would let this test
            # pass even if load_checkpoint accepted the file.
            load_checkpoint(model, tmpfile.name)
Exemple #14
0
def test_load_gpu_model_on_cpu_lean_checkpoint():
    """Lean-load the dense SSL example checkpoint and check the model is on the CPU."""
    target_device = 'cpu'
    cpu_device_id = -1
    checkpoint_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'

    net = load_lean_checkpoint(
        create_model(False, 'cifar10', 'resnet20_cifar', device_ids=cpu_device_id),
        checkpoint_path,
        model_device=target_device)
    assert distiller.model_device(net) == target_device
def test_weights_size_attr(dataset, arch, parallel):
    """Check the graph's ``weights_vol`` attribute for every Conv2d/Linear module."""
    net = create_model(False, dataset, arch, parallel=parallel)
    graph = SummaryGraph(net, distiller.get_dummy_input(dataset))

    distiller.assign_layer_fq_names(net)
    weighted_layer_types = (nn.Conv2d, nn.Linear)
    for layer_name, layer in net.named_modules():
        if not isinstance(layer, weighted_layer_types):
            continue
        graph_op = graph.find_op(layer_name)
        assert graph_op is not None
        # The graph's recorded volume must equal the volume of the module's weight tensor.
        assert graph_op['attrs']['weights_vol'] == distiller.volume(layer.weight)
Exemple #16
0
def setup_test(arch, dataset, parallel):
    """Create a model together with one ParameterMasker per named parameter.

    Returns:
        A tuple ``(model, zeros_mask_dict)`` where ``zeros_mask_dict`` maps
        each parameter name to a ``distiller.ParameterMasker`` for that name.
    """
    net = create_model(False, dataset, arch, parallel=parallel)
    assert net is not None

    # Create the masks
    masks = {
        param_name: distiller.ParameterMasker(param_name)
        for param_name, _ in net.named_parameters()
    }
    return net, masks
Exemple #17
0
def test_load_gpu_model_on_cpu_lean_checkpoint():
    """Load the dense SSL example checkpoint with ``lean_checkpoint=True`` on the CPU."""
    checkpoint_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    cpu_device_id = -1

    net = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=cpu_device_id)
    loaded = load_checkpoint(net, checkpoint_path, lean_checkpoint=True)
    net, scheduler, sgd, resumed_epoch, steps = loaded

    # With lean_checkpoint=True no scheduler or optimizer is expected back.
    assert scheduler is None
    assert sgd is None
    assert distiller.model_device(net) == 'cpu'
Exemple #18
0
def test_load_lean_checkpoint_2():
    """Lean-load the dense SSL example checkpoint while also passing an optimizer."""
    checkpoint_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'

    net = create_model(False, 'cifar10', 'resnet20_cifar', 0)
    supplied_sgd = torch.optim.SGD(net.parameters(), lr=0.36787944117)
    loaded = load_checkpoint(net, checkpoint_path, supplied_sgd, lean_checkpoint=True)
    net, scheduler, returned_sgd, resumed_epoch, steps = loaded

    # With lean_checkpoint=True, scheduler and optimizer come back as None
    # and the epoch count starts at 0.
    assert scheduler is None
    assert returned_sgd is None
    assert resumed_epoch == 0
Exemple #19
0
def create_graph(dataset, arch):
    """Build a SummaryGraph for a non-parallel ``arch`` model on ``dataset``.

    A random single-sample input is created for the dataset: ``(1, 3, 224, 224)``
    for ``'imagenet'`` and ``(1, 3, 32, 32)`` for ``'cifar10'``.

    Args:
        dataset: dataset name; only ``'imagenet'`` and ``'cifar10'`` are handled.
        arch: architecture name, passed to ``create_model``.

    Returns:
        The ``SummaryGraph`` built from the model and the dummy input.

    Raises:
        AssertionError: if ``dataset`` is not one of the handled names, or if
            ``create_model`` returns None.
    """
    # Start from None so that an unsupported dataset reaches the assertion
    # below instead of failing with UnboundLocalError.
    dummy_input = None
    if dataset == 'imagenet':
        dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
    elif dataset == 'cifar10':
        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(
        dataset)

    model = create_model(False, dataset, arch, parallel=False)
    assert model is not None
    # Place the input on the same device as the model before tracing.
    dummy_input = dummy_input.to(distiller.model_device(model))
    return SummaryGraph(model, dummy_input)
Exemple #20
0
def test_load_state_dict():
    """Load a checkpoint that contains nothing but a ``'state_dict'`` entry."""
    # prepare lean checkpoint
    source_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    weights_only = torch.load(source_path).get('state_dict')

    with tempfile.NamedTemporaryFile() as lean_file:
        torch.save({'state_dict': weights_only}, lean_file.name)
        net = create_model(False, 'cifar10', 'resnet20_cifar')
        net, scheduler, resumed_epoch = load_checkpoint(net, lean_file.name)

    # The model must have at least as many modules as the state-dict has
    # '...weight' entries, and there must be at least one such entry.
    module_count = len(list(net.named_modules()))
    weight_key_count = len([key for key in weights_only if key.endswith('weight')])
    assert module_count >= weight_key_count > 0
    assert scheduler is None
    assert resumed_epoch == 0
Exemple #21
0
def test_load_state_dict_implicit():
    """Load a state-dict-only checkpoint without requesting a lean load."""
    # prepare lean checkpoint
    source_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    weights_only = torch.load(source_path).get('state_dict')

    with tempfile.NamedTemporaryFile() as lean_file:
        torch.save({'state_dict': weights_only}, lean_file.name)
        net = create_model(False, 'cifar10', 'resnet20_cifar')
        loaded = load_checkpoint(net, lean_file.name)
        net, scheduler, sgd, resumed_epoch = loaded

    # Nothing besides the weights was stored, so the rest comes back empty.
    assert scheduler is None
    assert sgd is None
    assert resumed_epoch == 0
Exemple #22
0
def test_load_lean_checkpoint_1():
    """Lean-load a state-dict-only checkpoint while also passing an optimizer."""
    # prepare lean checkpoint
    source_path = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    weights_only = torch.load(source_path).get('state_dict')

    with tempfile.NamedTemporaryFile() as lean_file:
        torch.save({'state_dict': weights_only}, lean_file.name)
        net = create_model(False, 'cifar10', 'resnet20_cifar')
        supplied_sgd = torch.optim.SGD(net.parameters(), lr=0.36787944117)
        loaded = load_checkpoint(net, lean_file.name, supplied_sgd, lean_checkpoint=True)
        net, scheduler, returned_sgd, resumed_epoch, steps = loaded

    assert scheduler is None
    assert returned_sgd is None
    assert resumed_epoch == 0
Exemple #23
0
def test_load_gpu_model_on_cpu():
    """Load the ResNet20 CIFAR10 checkpoint into a model created with device id -1.

    Regression test for issue #148: both the model and the loaded optimizer
    are expected to report the CPU device.
    """
    cpu_name = 'cpu'
    cpu_device_id = -1
    checkpoint_path = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'

    net = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=cpu_device_id)
    loaded = load_checkpoint(net, checkpoint_path)
    net, scheduler, sgd, resumed_epoch = loaded

    assert scheduler is not None
    assert sgd is not None
    assert distiller.utils.optimizer_device_name(sgd) == cpu_name
    assert resumed_epoch == 1
    assert distiller.model_device(net) == cpu_name
Exemple #24
0
def test_load():
    """Load the ResNet20 CIFAR10 checkpoint and compare the restored optimizer state."""
    logging.getLogger('simple_example').setLevel(logging.INFO)

    checkpoint_path = 'checkpoints/resnet20_cifar10_checkpoint.pth.tar'
    saved_optimizer_state = torch.load(checkpoint_path)['optimizer_state_dict']

    net = create_model(False, 'cifar10', 'resnet20_cifar', 0)
    loaded = load_checkpoint(net, checkpoint_path)
    net, scheduler, sgd, resumed_epoch = loaded

    assert scheduler is not None
    assert sgd is not None, 'Failed to load the optimizer'
    states_match = _is_similar_param_groups(saved_optimizer_state, sgd.state_dict())
    if not states_match:
        # Reached only on a mismatch; the comparison is repeated as an assert
        # so that the failure is reported through the test framework.
        assert saved_optimizer_state == sgd.state_dict() # this will always fail
    assert resumed_epoch == 1
Exemple #25
0
def init_knowledge_distillation(args, model, compression_scheduler):
    """Set up student-teacher knowledge distillation when ``args.kd_teacher`` is set.

    ``args.kd_policy`` is always reset to None first. When a teacher is
    requested, the teacher model is created (and optionally loaded from
    ``args.kd_resume``), a KnowledgeDistillationPolicy is stored in
    ``args.kd_policy`` and registered with ``compression_scheduler``, and the
    configuration is logged.
    """
    args.kd_policy = None
    if not args.kd_teacher:
        return

    teacher_model = create_model(args.kd_pretrained, args.dataset, args.kd_teacher,
                                 device_ids=args.gpus)
    if args.kd_resume:
        teacher_model = apputils.load_lean_checkpoint(teacher_model, args.kd_resume)

    loss_weights = distiller.DistillationLossWeights(
        args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)
    args.kd_policy = distiller.KnowledgeDistillationPolicy(
        model, teacher_model, args.kd_temp, loss_weights)
    # The policy runs every epoch from kd_start_epoch up to args.epochs.
    compression_scheduler.add_policy(
        args.kd_policy,
        starting_epoch=args.kd_start_epoch,
        ending_epoch=args.epochs,
        frequency=1)

    msglogger.info('\nStudent-Teacher knowledge distillation enabled:')
    msglogger.info('\tTeacher Model: %s', args.kd_teacher)
    msglogger.info('\tTemperature: %s', args.kd_temp)
    formatted_weights = ' | '.join(['{:.2f}'.format(val) for val in loss_weights])
    msglogger.info('\tLoss Weights (distillation | student | teacher): %s', formatted_weights)
    msglogger.info('\tStarting from Epoch: %s', args.kd_start_epoch)
Exemple #26
0
def test_load():
    """Load the dense SSL example checkpoint, passing an optimizer to restore into.

    Checks that a scheduler and optimizer are returned, that the restored
    optimizer state matches the state stored in the checkpoint, that the
    start epoch is 180, and that no ``train_steps`` value is returned.
    """
    logger = logging.getLogger('simple_example')
    logger.setLevel(logging.INFO)

    checkpoint_filename = '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar'
    src_optimizer = torch.load(checkpoint_filename)['optimizer']

    model = create_model(False, 'cifar10', 'resnet20_cifar', 0)
    model, compression_scheduler, optimizer, start_epoch, train_steps = load_checkpoint(
        model, checkpoint_filename,
        torch.optim.SGD(model.parameters(), lr=0.36787944117))
    assert compression_scheduler is not None
    assert optimizer is not None, 'Failed to load the optimizer'
    if not _is_similar_param_groups(src_optimizer, optimizer.state_dict()):
        # Reached only on a mismatch; repeated as an assert to report the failure.
        assert src_optimizer == optimizer.state_dict() # this will always fail
    assert start_epoch == 180
    # Identity check: `== None` can be overridden by a custom __eq__.
    assert train_steps is None # the field isn't present in current checkpoint
Exemple #27
0
def test_utils():
    """Exercise model_find_param and model_find_module_name on a non-parallel model."""
    net = models.create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
    assert net is not None

    # An empty parameter name is expected to find nothing.
    assert distiller.model_find_param(net, "") is None

    # Search for a parameter by its "non-parallel" name
    assert distiller.model_find_param(net, "layer1.0.conv1.weight") is not None

    # Search for a module name
    target_name = "layer1.0.conv1"
    target_module = next(
        (module for name, module in net.named_modules() if name == target_name),
        None)
    assert target_module is not None

    # Looking the module object up again must give back the same name.
    assert distiller.model_find_module_name(net, target_module) == target_name
Exemple #28
0
def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
    """Test removal of arbitrary channels.
    The test receives a specification of channels to remove.
    Based on this specification, the channels are pruned and then physically
    removed from the model (via a "thinning" process).

    Args:
        config: test configuration; this function reads its ``arch``,
            ``dataset``, ``module_pairs`` and ``bn_name`` attributes.
        channels_to_remove: the channels to prune from the second module of
            the first module pair; its length determines the expected
            remaining channel count.
        is_parallel: forwarded to ``common.setup_test`` and to
            ``create_model`` as ``parallel``.
    """
    model, zeros_mask_dict = common.setup_test(config.arch, config.dataset, is_parallel)

    # Only the first module pair is exercised: pair[0] is looked up as `conv1`
    # after thinning, pair[1] is the module (`conv2`) whose input channels are pruned.
    pair = config.module_pairs[0]
    conv2 = common.find_module_by_name(model, pair[1])
    assert conv2 is not None

    # Test that we can access the weights tensor of the first convolution in layer 1
    conv2_p = distiller.model_find_param(model, pair[1] + ".weight")
    assert conv2_p is not None

    assert conv2_p.dim() == 4
    num_channels = conv2_p.size(1)
    cnt_nnz_channels = num_channels - len(channels_to_remove)
    mask = create_channels_mask(conv2_p, channels_to_remove)
    # The mask's channel density must equal the fraction of channels that are kept.
    assert distiller.density_ch(mask) == (conv2.in_channels - len(channels_to_remove)) / conv2.in_channels
    # Cool, so now we have a mask for pruning our channels.

    # Use the mask to prune
    zeros_mask_dict[pair[1] + ".weight"].mask = mask
    zeros_mask_dict[pair[1] + ".weight"].apply_mask(conv2_p)
    all_channels = set([ch for ch in range(num_channels)])
    nnz_channels = set(distiller.find_nonzero_channels_list(conv2_p, pair[1] + ".weight"))
    channels_removed = all_channels - nnz_channels
    logger.info("Channels removed {}".format(channels_removed))

    # Now, let's do the actual network thinning
    distiller.remove_channels(model, zeros_mask_dict, config.arch, config.dataset, optimizer=None)
    conv1 = common.find_module_by_name(model, pair[0])
    assert conv1
    # Both sides of the pruned connection must now have the reduced channel count.
    assert conv1.out_channels == cnt_nnz_channels
    assert conv2.in_channels == cnt_nnz_channels
    assert conv1.weight.size(0) == cnt_nnz_channels
    assert conv2.weight.size(1) == cnt_nnz_channels
    if config.bn_name is not None:
        # The batch-norm module named by the config must have been resized as well.
        bn1 = common.find_module_by_name(model, config.bn_name)
        assert bn1.running_var.size(0) == cnt_nnz_channels
        assert bn1.running_mean.size(0) == cnt_nnz_channels
        assert bn1.num_features == cnt_nnz_channels
        assert bn1.bias.size(0) == cnt_nnz_channels
        assert bn1.weight.size(0) == cnt_nnz_channels

    # Run a forward/backward pass through the thinned model.
    dummy_input = common.get_dummy_input(config.dataset)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
    run_forward_backward(model, optimizer, dummy_input)

    # Let's test saving and loading a thinned model.
    # We save 3 times, and load twice, to make sure to cover some corner cases:
    #   - Make sure that after loading, the model still has hold of the thinning recipes
    #   - Make sure that after a 2nd load, there no problem loading (in this case, the
    #   - tensors are already thin, so this is a new flow)
    # (1)
    # This first save passes no scheduler; loading it into a fresh model is
    # expected to raise KeyError (checked below).
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None)
    model_2 = create_model(False, config.dataset, config.arch, parallel=is_parallel)
    model(dummy_input)
    model_2(dummy_input)
    conv2 = common.find_module_by_name(model_2, pair[1])
    assert conv2 is not None
    with pytest.raises(KeyError):
        model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    compression_scheduler = distiller.CompressionScheduler(model)
    # NOTE(review): the result of this hasattr() call is discarded, so nothing
    # is checked here - presumably an `assert` was intended; confirm before changing.
    hasattr(model, 'thinning_recipes')

    run_forward_backward(model, optimizer, dummy_input)

    # (2)
    # Save again, this time with a scheduler, and load into the fresh model.
    save_checkpoint(epoch=0, arch=config.arch, model=model, optimizer=None, scheduler=compression_scheduler)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done")

    # (3)
    # Save the loaded model itself and load it once more.
    save_checkpoint(epoch=0, arch=config.arch, model=model_2, optimizer=None, scheduler=compression_scheduler)
    model_2 = load_lean_checkpoint(model_2, 'checkpoint.pth.tar')
    assert hasattr(model_2, 'thinning_recipes')
    logger.info("test_arbitrary_channel_pruning - Done 2")
Exemple #29
0
def main():
    script_dir = os.path.dirname(__file__)
    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
    global msglogger

    # Parse arguments
    args = parser.get_parser().parse_args()
    if args.epochs is None:
        args.epochs = 90

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

    # Log various details about the execution environment.  It is sometimes useful
    # to refer to past experiment executions and this information may be useful.
    apputils.log_execution_env_state(args.compress, msglogger.logdir, gitroot=module_path)
    msglogger.debug("Distiller: %s", distiller.__version__)

    start_epoch = 0
    ending_epoch = args.epochs
    perf_scores_history = []

    if args.evaluate:
        args.deterministic = True
    if args.deterministic:
        # Experiment reproducibility is sometimes important.  Pete Warden expounded about this
        # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        distiller.set_deterministic()  # Use a well-known seed, for repeatability of experiments
    else:
        # Turn on CUDNN benchmark mode for best performance. This is usually "safe" for image
        # classification models, as the input sizes don't change during the run
        # See here: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
        cudnn.benchmark = True

    if args.cpu or not torch.cuda.is_available():
        # Set GPU index to -1 if using CPU
        args.device = 'cpu'
        args.gpus = -1
    else:
        args.device = 'cuda'
        if args.gpus is not None:
            try:
                args.gpus = [int(s) for s in args.gpus.split(',')]
            except ValueError:
                raise ValueError('ERROR: Argument --gpus must be a comma-separated list of integers only')
            available_gpus = torch.cuda.device_count()
            for dev_id in args.gpus:
                if dev_id >= available_gpus:
                    raise ValueError('ERROR: GPU device ID {0} requested, but only {1} devices available'
                                     .format(dev_id, available_gpus))
            # Set default device in case the first one on the list != 0
            torch.cuda.set_device(args.gpus[0])

    # Infer the dataset from the model name
    args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
    args.num_classes = 10 if args.dataset == 'cifar10' else 1000

    if args.earlyexit_thresholds:
        args.num_exits = len(args.earlyexit_thresholds) + 1
        args.loss_exits = [0] * args.num_exits
        args.losses_exits = []
        args.exiterrors = []

    # Create the model
    model = create_model(args.pretrained, args.dataset, args.arch,
                         parallel=not args.load_serialized, device_ids=args.gpus)
    compression_scheduler = None
    # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
    # that can be read by Google's Tensor Board.  PythonLogger writes to the Python logger.
    tflogger = TensorBoardLogger(msglogger.logdir)
    pylogger = PythonLogger(msglogger)

    # capture thresholds for early-exit training
    if args.earlyexit_thresholds:
        msglogger.info('=> using early-exit threshold values of %s', args.earlyexit_thresholds)

    # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1
    if args.deprecated_resume:
        msglogger.warning('The "--resume" flag is deprecated. Please use "--resume-from=YOUR_PATH" instead.')
        if not args.reset_optimizer:
            msglogger.warning('If you wish to also reset the optimizer, call with: --reset-optimizer')
            args.reset_optimizer = True
        args.resumed_checkpoint_path = args.deprecated_resume

    # We can optionally resume from a checkpoint
    optimizer = None
    if args.resumed_checkpoint_path:
        model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
            model, args.resumed_checkpoint_path, model_device=args.device)
    elif args.load_model_path:
        model = apputils.load_lean_checkpoint(model, args.load_model_path,
                                              model_device=args.device)
    if args.reset_optimizer:
        start_epoch = 0
        if optimizer is not None:
            optimizer = None
            msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(args.device)

    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(),
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        msglogger.info('Optimizer Type: %s', type(optimizer))
        msglogger.info('Optimizer Args: %s', optimizer.defaults)

    if args.AMC:
        return automated_deep_compression(model, criterion, optimizer, pylogger, args)
    if args.greedy:
        return greedy(model, criterion, optimizer, pylogger, args)

    # This sample application can be invoked to produce various summary reports.
    if args.summary:
        return summarize_model(model, args.dataset, which_summary=args.summary)

    activations_collectors = create_activation_stats_collectors(model, *args.activation_stats)

    if args.qe_calibration:
        msglogger.info('Quantization calibration stats collection enabled:')
        msglogger.info('\tStats will be collected for {:.1%} of test dataset'.format(args.qe_calibration))
        msglogger.info('\tSetting constant seeds and converting model to serialized execution')
        distiller.set_deterministic()
        model = distiller.make_non_parallel_copy(model)
        activations_collectors.update(create_quantization_stats_collector(model))
        args.evaluate = True
        args.effective_test_size = args.qe_calibration

    # Load the datasets: the dataset to load is inferred from the model name passed
    # in args.arch.  The default dataset is ImageNet, but if args.arch contains the
    # substring "_cifar", then cifar10 is used.
    train_loader, val_loader, test_loader, _ = apputils.load_data(
        args.dataset, os.path.expanduser(args.data), args.batch_size,
        args.workers, args.validation_split, args.deterministic,
        args.effective_train_size, args.effective_valid_size, args.effective_test_size)
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))

    if args.sensitivity is not None:
        sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2])
        return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)

    if args.evaluate:
        return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args,
                              compression_scheduler)

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler,
            (start_epoch-1) if args.resumed_checkpoint_path else None)
        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
        model.to(args.device)
    elif compression_scheduler is None:
        compression_scheduler = distiller.CompressionScheduler(model)

    if args.thinnify:
        #zeros_mask_dict = distiller.create_model_masks_dict(model)
        assert args.resumed_checkpoint_path is not None, \
            "You must use --resume-from to provide a checkpoint file to thinnify"
        distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None)
        apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler,
                                 name="{}_thinned".format(args.resumed_checkpoint_path.replace(".pth.tar", "")),
                                 dir=msglogger.logdir)
        print("Note: your model may have collapsed to random inference, so you may want to fine-tune")
        return

    args.kd_policy = None
    if args.kd_teacher:
        teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)
        if args.kd_resume:
            teacher = apputils.load_lean_checkpoint(teacher, args.kd_resume)
        dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)
        args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
        compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,
                                         frequency=1)

        msglogger.info('\nStudent-Teacher knowledge distillation enabled:')
        msglogger.info('\tTeacher Model: %s', args.kd_teacher)
        msglogger.info('\tTemperature: %s', args.kd_temp)
        msglogger.info('\tLoss Weights (distillation | student | teacher): %s',
                       ' | '.join(['{:.2f}'.format(val) for val in dlw]))
        msglogger.info('\tStarting from Epoch: %s', args.kd_start_epoch)

    if start_epoch >= ending_epoch:
        msglogger.error(
            'epoch count is too low, starting epoch is {} but total epochs set to {}'.format(
            start_epoch, ending_epoch))
        raise ValueError('Epochs parameter is too low. Nothing to do.')
    for epoch in range(start_epoch, ending_epoch):
        # This is the main training loop.
        msglogger.info('\n')
        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch,
                metrics=(vloss if (epoch != start_epoch) else 10**6))

        # Train for one epoch
        with collectors_context(activations_collectors["train"]) as collectors:
            train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
                  loggers=[tflogger, pylogger], args=args)
            distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
            distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
                                                collector=collectors["sparsity"])
            if args.masks_sparsity:
                msglogger.info(distiller.masks_sparsity_tbl_summary(model, compression_scheduler))

        # evaluate on validation set
        with collectors_context(activations_collectors["valid"]) as collectors:
            top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
            distiller.log_activation_statsitics(epoch, "valid", loggers=[tflogger],
                                                collector=collectors["sparsity"])
            save_collectors_data(collectors, msglogger.logdir)

        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', vloss),
                              ('Top1', top1),
                              ('Top5', top5)]))
        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1,
                                        loggers=[tflogger])

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch, optimizer)

        # Update the list of top scores achieved so far, and save the checkpoint
        update_training_scores_history(perf_scores_history, model, top1, top5, epoch, args.num_best_scores)
        is_best = epoch == perf_scores_history[0].epoch
        checkpoint_extras = {'current_top1': top1,
                             'best_top1': perf_scores_history[0].top1,
                             'best_epoch': perf_scores_history[0].epoch}
        apputils.save_checkpoint(epoch, args.arch, model, optimizer=optimizer, scheduler=compression_scheduler,
                                 extras=checkpoint_extras, is_best=is_best, name=args.name, dir=msglogger.logdir)

    # Finally run results on the test set
    test(test_loader, model, criterion, [pylogger], activations_collectors, args=args)
def create_graph(dataset, arch, parallel=False):
    """Return a ``SummaryGraph`` for a newly created, non-pretrained model.

    A dummy input is requested for ``dataset`` first, then the model is
    created, and finally both are handed to ``SummaryGraph``.

    Args:
        dataset: Dataset identifier, forwarded to both
            ``distiller.get_dummy_input`` and ``create_model``.
        arch: Architecture name, forwarded to ``create_model``.
        parallel: Forwarded as the fourth positional argument of
            ``create_model``; defaults to ``False``.
            NOTE(review): presumably toggles a data-parallel wrapper --
            confirm against ``create_model``.

    Returns:
        The ``SummaryGraph`` built from the model and the dummy input.

    Raises:
        AssertionError: If ``create_model`` returns ``None``.
    """
    # Keep this call ahead of model creation so that a failure here surfaces
    # before create_model is ever invoked.
    sample_input = distiller.get_dummy_input(dataset)
    # The leading False requests a model without pretrained weights.
    net = create_model(False, dataset, arch, parallel)
    assert net is not None
    graph = SummaryGraph(net, sample_input)
    return graph