def main():
    """Parse the command line and launch landmarks detector training.

    Torch operators are patched before training only when a compression
    config is supplied via ``-c/--compr_config``. Training itself runs with
    the CUDA device given by ``--device`` selected as the current device.
    """
    parser = argparse.ArgumentParser(description='Training Landmarks detector in PyTorch')
    parser.add_argument('--train_data_root', dest='train', required=True, type=str, help='Path to train data.')
    parser.add_argument('--train_list', dest='t_list', required=False, type=str, help='Path to train data image list.')
    parser.add_argument('--train_landmarks', default='', dest='t_land', required=False, type=str,
                        help='Path to landmarks for the train images.')
    parser.add_argument('--train_batch_size', type=int, default=170, help='Train batch size.')
    parser.add_argument('--epoch_total_num', type=int, default=30, help='Number of epochs to train.')
    parser.add_argument('--lr', type=float, default=0.4, help='Learning rate.')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
    parser.add_argument('--val_step', type=int, default=2000, help='Evaluate model each val_step during each epoch.')
    parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay.')
    parser.add_argument('--device', '-d', default=0, type=int)
    parser.add_argument('--snap_folder', type=str, default='./snapshots/', help='Folder to save snapshots.')
    parser.add_argument('--snap_prefix', type=str, default='LandmarksNet', help='Prefix for snapshots.')
    parser.add_argument('--snap_to_resume', type=str, default=None, help='Snapshot to resume.')
    parser.add_argument('--dataset', choices=['vgg', 'celeb', 'ngd'], type=str, default='vgg', help='Dataset.')
    parser.add_argument('-c', '--compr_config', help='Path to a file with compression parameters', required=False)
    parser.add_argument('--to-onnx', type=str, metavar='PATH', default=None, help='Export to ONNX model by given path')
    arguments = parser.parse_args()

    # The parsed namespace is bound to `arguments`; reading it through any
    # other name (e.g. `args`) would raise NameError before training starts.
    if arguments.compr_config:
        patch_torch_operators()

    with torch.cuda.device(arguments.device):
        train(arguments)
Exemplo n.º 2
0
    def test_number_of_nodes_for_module_in_loop__not_input_node(self):
        """Check the traced node count when a looped module is fed through relu.

        The expected count is the inner module's node count (7) plus one per
        loop iteration, as computed by ``LoopModule.nodes_number``.
        """
        loop_count = 5
        patch_torch_operators()

        class LoopModule(nn.Module):
            class Inner(nn.Module):
                @staticmethod
                def nodes_number():
                    return 7

                def forward(self, x):
                    sig = F.sigmoid(x)
                    tanh_out = F.tanh(x)
                    # Operators are deliberately called again here; keep the call order.
                    return F.sigmoid(x) * tanh_out + F.tanh(x) * sig

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def nodes_number(self):
                return self.inner.nodes_number() + loop_count

            def forward(self, x):
                for _ in range(loop_count):
                    activated = F.relu(x)
                    x = self.inner(activated)
                return x

        test_module = LoopModule()
        reset_context('test')
        with context('test') as ctx:
            _ = test_module(torch.zeros(1))
            traced_nodes = ctx.graph.get_nodes_count()
            assert traced_nodes == test_module.nodes_number()
Exemplo n.º 3
0
def test_export_stacked_bi_lstm(tmp_path):
    """Export a quantized two-layer bidirectional NNCF LSTM to ONNX.

    The exported file must exist and contain exactly 54 ``FakeQuantize`` nodes.
    """
    sizes = LSTMTestSizes(3, 3, 3, 3)
    patch_torch_operators()
    sample_size = (1, sizes.hidden_size, sizes.input_size)
    config = get_empty_config(input_sample_size=sample_size)
    config['compression'] = {'algorithm': 'quantization'}

    config.log_dir = str(tmp_path)
    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)
    # TODO: batch_first=True fails with building graph: ambiguous call to mul or sigmoid
    rnn_options = dict(input_size=sizes.input_size,
                       hidden_size=sizes.hidden_size,
                       num_layers=2,
                       bidirectional=True,
                       batch_first=False)
    test_rnn = NNCF_RNN('LSTM', **rnn_options)
    algo, _ = create_compressed_model(test_rnn, config)

    onnx_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(onnx_path)
    assert os.path.exists(onnx_path)

    exported = onnx.load(onnx_path)
    # pylint: disable=no-member
    fake_quantize_nodes = [node for node in exported.graph.node if node.op_type == 'FakeQuantize']
    assert len(fake_quantize_nodes) == 54
Exemplo n.º 4
0
    def test_number_of_nodes_for_module_with_nested_loops(self):
        """Check the traced node count for two nested loops of registered modules.

        Both loop bodies are wrapped in helper modules decorated with
        ``ITERATION_MODULES.register()``. After one forward pass inside the
        'test' context, the graph is expected to hold exactly ``num_iter`` nodes.
        """
        num_iter = 5
        patch_torch_operators()

        class TestIterModule(nn.Module):
            # Outer loop body: applies relu and delegates to the inner loop module.
            @ITERATION_MODULES.register()
            class TestIterModule_ResetPoint(nn.Module):
                def __init__(self, loop_module):
                    super().__init__()
                    self.loop_module = loop_module

                def forward(self, x):
                    return self.loop_module(F.relu(x))

            def __init__(self):
                super().__init__()
                self.loop_module = self.LoopModule2()
                self.reset_point = self.TestIterModule_ResetPoint(
                    self.loop_module)

            def forward(self, x):
                # Outer loop: num_iter passes through the registered reset point.
                for _ in range(num_iter):
                    x = self.reset_point(x)
                return x

            class LoopModule2(nn.Module):
                # Inner loop body: applies relu and delegates to Inner.
                @ITERATION_MODULES.register()
                class LoopModule2_ResetPoint(nn.Module):
                    def __init__(self, inner):
                        super().__init__()
                        self.inner = inner

                    def forward(self, x):
                        return self.inner(F.relu(x))

                def __init__(self):
                    super().__init__()
                    self.inner = self.Inner()
                    self.reset_helper = self.LoopModule2_ResetPoint(self.inner)

                def forward(self, x):
                    # NOTE: the helper's output is discarded; x is returned unchanged.
                    for _ in range(num_iter):
                        self.reset_helper(x)
                    return x

                class Inner(nn.Module):
                    def forward(self, x):
                        s = F.sigmoid(x)
                        t = F.tanh(x)
                        result = t + s
                        return result

        test_module = TestIterModule()
        reset_context('test')
        with context('test') as ctx:
            _ = test_module(torch.zeros(1))
            assert ctx.graph.get_nodes_count() == num_iter
def main():
    """Parse the evaluation command line and run landmarks network evaluation.

    Torch operators are patched only when a compression config is given; the
    evaluation itself runs with the requested CUDA device selected.
    """
    parser = argparse.ArgumentParser(description='Evaluation script for landmarks detection network')
    add_option = parser.add_argument
    add_option('--device', '-d', default=0, type=int)
    add_option('--val_data_root', dest='val', required=True, type=str, help='Path to val data.')
    add_option('--val_list', dest='v_list', required=False, type=str, help='Path to test data image list.')
    add_option('--val_landmarks', dest='v_land', default='', required=False, type=str,
               help='Path to landmarks for test images.')
    add_option('--val_batch_size', type=int, default=1, help='Validation batch size.')
    add_option('--snapshot', type=str, default=None, help='Snapshot to evaluate.')
    add_option('--dataset', choices=['vgg', 'celeb', 'ngd'], type=str, default='vgg', help='Dataset.')
    add_option('-c', '--compr_config', help='Path to a file with compression parameters', required=False)
    options = parser.parse_args()

    if options.compr_config:
        patch_torch_operators()

    with torch.cuda.device(options.device):
        start_evaluation(options)
def test_can_restore_binary_mask_on_magnitude_quant_algo_resume(tmp_path):
    """Load a magnitude-sparsity + quantization state into a const-sparsity model.

    After ``load_state`` the binary masks of ``conv1`` and ``conv2`` in the
    const-sparsity model must equal ``ref_mask_1`` and ``ref_mask_2``.
    """
    patch_torch_operators()
    magnitude_sparsity = {
        "algorithm": "magnitude_sparsity",
        "weight_importance": "abs",
        "params": {
            "schedule": "multistep",
            "sparsity_levels": [0.3, 0.5]
        }
    }
    config = get_empty_config()
    config["compression"] = [magnitude_sparsity, {"algorithm": "quantization"}]
    config.log_dir = str(tmp_path)
    reset_context('orig')
    reset_context('quantized_graphs')
    _, model = create_compressed_model(MagnitudeTestModel(), config)
    # load_state doesn't support CPU + Quantization
    sparse_model = torch.nn.DataParallel(model)
    sparse_model.cuda()
    with torch.no_grad():
        sparse_model(torch.ones([1, 1, 10, 10]))

    reset_context('orig')
    reset_context('quantized_graphs')
    resume_config = get_empty_config()
    resume_config.log_dir = str(tmp_path)
    resume_config["compression"] = [{"algorithm": "const_sparsity"}, {"algorithm": "quantization"}]
    _, const_sparse_model = create_compressed_model(MagnitudeTestModel(), resume_config)

    load_state(const_sparse_model, sparse_model.state_dict())

    # Compare the restored mask of each convolution with its reference.
    wrapped = const_sparse_model.get_nncf_wrapped_module()
    pre_op = wrapped.conv1.pre_ops['0']
    check_equal(ref_mask_1, pre_op.operand.binary_mask)

    wrapped = const_sparse_model.get_nncf_wrapped_module()
    pre_op = wrapped.conv2.pre_ops['0']
    check_equal(ref_mask_2, pre_op.operand.binary_mask)
Exemplo n.º 7
0
def test_export_lstm_cell(tmp_path):
    """Export a quantized ``LSTMCellNNCF(1, 1)`` to ONNX.

    The exported file must exist and contain exactly 13 ``FakeQuantize`` nodes.
    """
    patch_torch_operators()
    config = get_empty_config(model_size=1, input_sample_size=(1, 1))
    config['compression'] = {'algorithm': 'quantization'}

    config.log_dir = str(tmp_path)
    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)
    algo, _ = create_compressed_model(LSTMCellNNCF(1, 1), config)

    onnx_path = str(tmp_path.joinpath('test.onnx'))
    algo.export_model(onnx_path)
    assert os.path.exists(onnx_path)

    exported = onnx.load(onnx_path)
    # pylint: disable=no-member
    fake_quantize_total = sum(1 for node in exported.graph.node if node.op_type == 'FakeQuantize')
    assert fake_quantize_total == 13
Exemplo n.º 8
0
    def test_number_of_nodes_for_module_in_loop(self):
        """Check the traced node count when one module is called repeatedly in a loop.

        The expected count equals the inner module's own node count (3),
        regardless of the number of loop iterations.
        """
        loop_count = 5
        patch_torch_operators()

        class LoopModule(nn.Module):
            class Inner(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.operator1 = torch.sigmoid
                    self.operator2 = torch.tanh

                @staticmethod
                def nodes_number():
                    return 3

                def forward(self, x):
                    sig = self.operator1(x)
                    tanh_out = self.operator2(x)
                    return tanh_out + sig

            def __init__(self):
                super().__init__()
                self.inner = self.Inner()

            def nodes_number(self):
                return self.inner.nodes_number()

            def forward(self, x):
                for _ in range(loop_count):
                    x = self.inner(x)
                return x

        test_module = LoopModule()
        reset_context('test')
        with context('test') as ctx:
            _ = test_module(torch.zeros(1))
            traced_nodes = ctx.graph.get_nodes_count()
            assert traced_nodes == test_module.nodes_number()
Exemplo n.º 9
0
    def test_number_of_nodes_for_repeated_module(self):
        """Check the node count after calling the same module twice in one context.

        The graph is expected to hold 4 nodes after the first call and 8 after
        the second call, which is fed the first call's output.
        """
        patch_torch_operators()

        class LoopModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.operator = F.relu
                convolutions = [nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)]
                self.layers = nn.ModuleList(convolutions)

            def forward(self, x):
                for conv in self.layers:
                    conv_out = conv(x)
                    x = F.relu(conv_out)
                return x

        test_module = LoopModule()
        reset_context('test')
        with context('test') as ctx:
            first_output = test_module(torch.zeros(1, 1, 1, 1))
            nodes_after_first_call = ctx.graph.get_nodes_count()
            assert nodes_after_first_call == 4
            _ = test_module(first_output)
            nodes_after_second_call = ctx.graph.get_nodes_count()
            assert nodes_after_second_call == 8
def validate_torch_model(output_dir,
                         config,
                         num_layers,
                         dump,
                         val_loader=None,
                         cuda=False):
    """Load a torch model, optionally attach dump hooks, and validate it.

    Args:
        output_dir: Directory passed to the print hooks; created if missing.
        config: Configuration passed to ``load_torch_model``.
        num_layers: Forwarded to ``register_print_hooks`` as ``num_layers``.
        dump: When truthy, print hooks with ``dump_activations=True`` are
            registered on the model before validation.
        val_loader: Validation data loader passed to ``validate_general``.
        cuda: Forwarded to ``load_torch_model`` and ``validate_general``.
    """
    from nncf.dynamic_graph import patch_torch_operators
    from tools.debug.common import load_torch_model, register_print_hooks

    patch_torch_operators()
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(output_dir, exist_ok=True)

    model = load_torch_model(config, cuda)

    model_e = model.eval()
    if dump:
        register_print_hooks(output_dir,
                             model_e,
                             num_layers=num_layers,
                             data_to_compare=None,
                             dump_activations=True)

    validate_general(val_loader, model_e, infer_pytorch_model, cuda)
def test_model_can_be_loaded_with_resume(_params, tmp_path):
    """Build a compressed model from a config and load a checkpoint in resume mode.

    ``_params`` supplies the NNCF config path, the checkpoint path and the
    execution mode. Distributed settings are configured only for the
    distributed execution modes.
    """
    config_path = _params['nncf_config_path']
    checkpoint_path = _params['checkpoint_path']

    config = Config.from_json(str(config_path))
    config.execution_mode = _params['execution_mode']

    config.current_gpu = 0
    config.log_dir = str(tmp_path)
    config.device = get_device(config)
    distributed_modes = (ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    config.distributed = config.execution_mode in distributed_modes
    if config.distributed:
        config.dist_url = "tcp://127.0.0.1:9898"
        config.dist_backend = "nccl"
        config.rank = 0
        config.world_size = 1
        configure_distributed(config)

    model = load_model(config['model'],
                       pretrained=False,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'))

    patch_torch_operators()
    compression_algo, model = create_compressed_model(model, config)
    model, _ = prepare_model_for_execution(model, config)

    if config.distributed:
        compression_algo.distributed()

    for context_name in ('orig', 'quantized_graphs'):
        reset_context(context_name)
    saved_state = torch.load(checkpoint_path, map_location='cpu')
    load_state(model, saved_state['state_dict'], is_resume=True)
Exemplo n.º 12
0
from examples.common.models.classification import squeezenet1_1_custom
from nncf import Quantization, SymmetricQuantizer, AsymmetricQuantizer
from nncf import utils
from nncf.algo_selector import create_compression_algorithm
from nncf.compression_method_api import CompressionLoss, CompressionScheduler
from nncf.config import Config
from nncf.dynamic_graph import reset_context, patch_torch_operators
from nncf.dynamic_graph.graph_builder import create_input_infos
from nncf.helpers import safe_thread_call, create_compressed_model, load_state
from nncf.initialization import InitializingDataLoader, INITIALIZABLE_MODULES
from nncf.operations import UpdateWeight, UpdateInputs
from nncf.quantization.layers import QuantizationMode, QuantizerConfig
from nncf.utils import get_all_modules_by_type
from tests.test_helpers import BasicConvTestModel, TwoConvTestModel, get_empty_config

# Module-level call: runs once when this module is imported, before any of
# the functions defined below can be called.
patch_torch_operators()


def get_basic_quantization_config(model_size=4):
    config = Config()
    config.update({
        "model": "basic_quant_conv",
        "model_size": model_size,
        "input_info": {
            "sample_size": (1, 1, model_size, model_size),
        },
        "compression": {
            "algorithm": "quantization",
            "initializer": {
                "num_init_steps": 0
            },
Exemplo n.º 13
0
    def test_number_of_calling_fq_for_gnmt(self, tmp_path):
        """Count quantizer invocations and graph nodes for a quantized GNMT model.

        Builds a GNMT model on CUDA device 0, compresses it with the
        quantization algorithm, attaches a counting forward pre-hook to every
        entry of ``model.all_quantizations`` and runs the dummy forward twice
        (full and halved sequence length). After each run it asserts the graph
        node count, the number of quantizers and the per-quantizer call counts.
        """
        torch.cuda.set_device(0)
        device = torch.device('cuda')
        batch_first = False
        vocab_size = 32000
        model_config = {
            'hidden_size': 100,
            'vocab_size': vocab_size,
            'num_layers': 4,
            'dropout': 0.2,
            'batch_first': batch_first,
            'share_embedding': True,
        }
        batch_size = 128
        sequence_size = 50
        # (batch, sequence) when batch_first, otherwise (sequence, batch).
        input_sample_size = (batch_size,
                             sequence_size) if batch_first else (sequence_size,
                                                                 batch_size)
        patch_torch_operators()
        config = get_empty_config(input_sample_size=input_sample_size)
        config['compression'] = \
            {'algorithm': 'quantization',
             'quantize_inputs': True,
             'quantizable_subgraph_patterns': [["linear", "__add__"],
                                               ["sigmoid", "__mul__", "__add__"],
                                               ["__add__", "tanh", "__mul__"],
                                               ["sigmoid", "__mul__"]],
             'scopes_without_shape_matching':
                 ['GNMT/ResidualRecurrentDecoder[decoder]/RecurrentAttention[att_rnn]/BahdanauAttention[attn]'],
             'disable_function_quantization_hooks': True}

        config.log_dir = str(tmp_path)
        reset_context('orig')
        reset_context('quantized_graphs')

        model = GNMT(**model_config)
        model = replace_lstm(model)
        model.to(device)

        def dummy_forward_fn(model, seq_len=sequence_size):
            """Run one forward pass of `model` on randomly generated token batches."""
            def gen_packed_sequence():
                """Return a padded batch of random sequences and their lengths.

                Lengths are drawn with ``random_(1, seq_len + 1)`` and sorted in
                descending order; token ids are drawn with ``random_(1, vocab_size)``.
                """
                seq_list = []
                seq_lens = torch.LongTensor(batch_size).random_(1, seq_len + 1)
                seq_lens = torch.sort(seq_lens, descending=True).values
                for seq_size in seq_lens:
                    seq_list.append(
                        torch.LongTensor(seq_size.item()).random_(
                            1, vocab_size).to(device))
                padded_seq_batch = torch.nn.utils.rnn.pad_sequence(
                    seq_list, batch_first=batch_first)
                return padded_seq_batch, seq_lens

            x_data, seq_lens = gen_packed_sequence()
            input_encoder = x_data
            input_enc_len = seq_lens.to(device)
            # Decoder input is a second, independently generated batch; its lengths are unused.
            input_decoder = gen_packed_sequence()[0]
            model.forward(input_encoder, input_enc_len, input_decoder)

        _, model = create_compressed_model(model, config, dummy_forward_fn)
        model.to(device)

        class Counter:
            """Minimal call counter incremented through `next`."""
            def __init__(self):
                self.count = 0

            def next(self):
                self.count += 1

        def hook(model, input_, counter):
            # Forward pre-hook: only bumps the counter bound via `partial` below.
            counter.next()

        # One counter per quantizer, keyed by the quantizer's name.
        counters = {}
        for name, quantizer in model.all_quantizations.items():
            counter = Counter()
            counters[name] = counter
            quantizer.register_forward_pre_hook(partial(hook, counter=counter))
        with context('quantized_graphs') as ctx:
            dummy_forward_fn(model)
            assert ctx.graph.get_nodes_count() == 239
            assert len(counters) == 68
            # Quantizers whose name mentions the LSTM cell are expected once per
            # sequence step; every other quantizer exactly once per forward pass.
            for name, counter in counters.items():
                if 'cell' in name or "LSTMCellForwardNNCF" in name:
                    assert counter.count == sequence_size, name
                else:
                    assert counter.count == 1, name
            new_seq_len = int(sequence_size / 2)
            dummy_forward_fn(model, new_seq_len)
            # The second pass must leave the node count unchanged; counters are
            # cumulative, so the expectations add the second run to the first.
            assert ctx.graph.get_nodes_count() == 239
            assert len(counters) == 68
            for name, counter in counters.items():
                if 'cell' in name or "LSTMCellForwardNNCF" in name:
                    assert counter.count == sequence_size + new_seq_len, name
                else:
                    assert counter.count == 2, name
Exemplo n.º 14
0
    def test_number_of_calling_fq_for_lstm(self, tmp_path):
        """Count quantizer invocations for a quantized two-layer bidirectional LSTM.

        After a single forward pass the graph must hold 110 nodes, there must be
        54 quantizers, and each quantizer must have been called ``seq_length`` times.
        """
        sizes = LSTMTestSizes(1, 1, 1, 5)
        layer_count = 2
        is_bidirectional = True
        direction_count = 2 if is_bidirectional else 1
        use_bias = True
        is_batch_first = False
        patch_torch_operators()
        sample_size = (sizes.seq_length, sizes.batch, sizes.input_size)
        config = get_empty_config(input_sample_size=sample_size)
        config['compression'] = {'algorithm': 'quantization', 'quantize_inputs': True}

        config.log_dir = str(tmp_path)
        reset_context('orig')
        reset_context('quantized_graphs')
        test_data = TestLSTMCell.generate_lstm_data(sizes,
                                                    layer_count,
                                                    direction_count,
                                                    bias=use_bias,
                                                    batch_first=is_batch_first)

        test_rnn = NNCF_RNN('LSTM',
                            input_size=sizes.input_size,
                            hidden_size=sizes.hidden_size,
                            num_layers=layer_count,
                            bidirectional=is_bidirectional,
                            bias=use_bias,
                            batch_first=is_batch_first)
        TestLSTM.set_ref_lstm_weights(test_data, test_rnn, layer_count, direction_count, use_bias)
        test_hidden = TestLSTM.get_test_lstm_hidden(test_data)

        # Contexts are reset a second time right before compression.
        reset_context('orig')
        reset_context('quantized_graphs')
        _, model = create_compressed_model(test_rnn, config)

        class Counter:
            def __init__(self):
                self.count = 0

            def next(self):
                self.count += 1

        def hook(model, input_, counter):
            counter.next()

        # One counter per quantizer, bumped by a forward pre-hook on each call.
        counters = {}
        for quantizer_name, quantizer in model.all_quantizations.items():
            call_counter = Counter()
            counters[quantizer_name] = call_counter
            counting_hook = partial(hook, counter=call_counter)
            quantizer.register_forward_pre_hook(counting_hook)
        with context('quantized_graphs') as ctx:
            _ = model(test_data.x, test_hidden)
            assert ctx.graph.get_nodes_count() == 110
            dot_path = os.path.join(config.log_dir, "compressed_graph_next.dot")
            ctx.graph.dump_graph(dot_path)
        assert len(counters) == 54
        for call_counter in counters.values():
            assert call_counter.count == sizes.seq_length
def main():
    """Parse the command line and evaluate a face recognition model.

    With ``--engine pt`` a PyTorch snapshot (``--snap``) is evaluated on the
    first device of ``--devices``; when a compression config is given, torch
    operators are patched and the model is wrapped by the compression
    algorithm before the snapshot is loaded. With ``--engine ie`` the model
    given by ``--fr_model`` (and optionally ``--lm_model``) is loaded through
    ``load_ie_model`` and evaluated on LFW.
    """
    parser = argparse.ArgumentParser(description='Evaluation script for Face Recognition in PyTorch')
    parser.add_argument('--devices', type=int, nargs='+', default=[0], help='CUDA devices to use.')
    parser.add_argument('--embed_size', type=int, default=128, help='Size of the face embedding.')
    parser.add_argument('--val_data_root', dest='val', required=True, type=str, help='Path to validation data.')
    # Help text fixed: this option is the validation image list, not the train one.
    parser.add_argument('--val_list', dest='v_list', required=True, type=str,
                        help='Path to validation data image list.')
    parser.add_argument('--val_landmarks', dest='v_land', default='', required=False, type=str,
                        help='Path to landmarks for the test images.')
    parser.add_argument('--val_batch_size', type=int, default=8, help='Validation batch size.')
    parser.add_argument('--snap', type=str, required=False, help='Snapshot to evaluate.')
    parser.add_argument('--roc_fname', type=str, default='', help='ROC file.')
    parser.add_argument('--dump_embeddings', action='store_true', help='Dump embeddings to summary writer.')
    parser.add_argument('--dist', choices=['l2', 'cos'], type=str, default='cos', help='Distance.')
    parser.add_argument('--flipped_emb', action='store_true', help='Flipped embedding concatenation trick.')
    parser.add_argument('--show_failed', action='store_true', help='Show misclassified pairs.')
    parser.add_argument('--model', choices=models_backbones.keys(), type=str, default='rmnet', help='Model type.')
    parser.add_argument('--engine', choices=['pt', 'ie'], type=str, default='pt', help='Framework to use for eval.')

    # IE-related options
    parser.add_argument('--fr_model', type=str, required=False)
    parser.add_argument('--lm_model', type=str, required=False)
    parser.add_argument('-pp', '--plugin_dir', type=str, default=None, help='Path to a plugin folder')
    parser.add_argument('-c', '--compr_config', help='Path to a file with compression parameters', required=False)
    args = parser.parse_args()

    if args.engine == 'pt':
        assert args.snap is not None, 'To evaluate PyTorch snapshot, please, specify --snap option.'

        if args.compr_config:
            patch_torch_operators()

        with torch.cuda.device(args.devices[0]):
            data, embeddings_fun = load_test_dataset(args)
            model = models_backbones[args.model](embedding_size=args.embed_size, feature=True)

            if args.compr_config:
                config = Config.from_json(args.compr_config)
                compression_algo = create_compression_algorithm(model, config)
                model = compression_algo.model

            model = load_model_state(model, args.snap, args.devices[0])
            evaluate(args, data, model, embeddings_fun, args.val_batch_size, args.dump_embeddings,
                     args.roc_fname, args.snap, True, args.show_failed)

            if args.compr_config:
                # Guard on the very key that is read: checking one key and then
                # indexing another could raise KeyError. Statistics are fetched once.
                stats = compression_algo.statistics()
                sparsity_key = 'sparsity_rate_for_sparsified_modules'
                if sparsity_key in stats:
                    log.info("Sparsity level: {0:.2f}".format(stats[sparsity_key]))
    else:
        from utils.ie_tools import load_ie_model

        assert args.fr_model is not None, 'To evaluate IE model, please, specify --fr_model option.'
        fr_model = load_ie_model(args.fr_model, 'CPU', args.plugin_dir)
        lm_model = None
        if args.lm_model:
            lm_model = load_ie_model(args.lm_model, 'CPU', args.plugin_dir)
        # Spatial size is taken from the model input shape, skipping its first two dimensions.
        input_size = tuple(fr_model.get_input_shape()[2:])

        lfw = LFW(args.val, args.v_list, args.v_land)
        if not lfw.use_landmarks or lm_model:
            # No dataset landmarks, or a landmarks model is supplied: resize + center crop.
            lfw.transform = t.Compose([ResizeNumpy(220), CenterCropNumpy(input_size)])
            lfw.use_landmarks = False
        else:
            log.info('Using landmarks for the LFW images.')
            lfw.transform = t.Compose([ResizeNumpy(input_size)])

        evaluate(args, lfw, fr_model, partial(compute_embeddings_lfw_ie, lm_model=lm_model), val_batch_size=1,
                 dump_embeddings=False, roc_fname='', snap_name='', verbose=True, show_failed=False)
Exemplo n.º 16
0
def main():
    """Parse the training command line and launch face recognition training.

    Options may also come from a yaml file referenced as ``@<path>`` on the
    command line. The parsed arguments are logged; torch operators are patched
    only when a compression config is given, and training runs on the first
    device listed in ``--devices``.
    """
    parser = ArgumentParserWithYaml(
        description='Training Face Recognition in PyTorch',
        fromfile_prefix_chars='@',
        epilog="Please, note that you can parse parameters from a yaml file if \
                                    you add @<path_to_yaml_file> to command line"
    )

    # Datasets configuration.
    parser.add_argument('--train_dataset', choices=['vgg', 'ms1m', 'trp', 'imdbface'], type=str, default='vgg',
                        help='Name of the train dataset.')
    parser.add_argument('--train_data_root', dest='train', required=True, type=str, help='Path to train data.')
    parser.add_argument('--train_list', dest='t_list', required=False, type=str,
                        help='Path to train data image list.')
    parser.add_argument('--train_landmarks', default='', dest='t_land', required=False, type=str,
                        help='Path to landmarks for the train images.')

    parser.add_argument('--val_data_root', dest='val', required=True, type=str, help='Path to val data.')
    parser.add_argument('--val_step', type=int, default=1000,
                        help='Evaluate model each val_step during each epoch.')
    parser.add_argument('--val_list', dest='v_list', required=True, type=str, help='Path to test data image list.')
    parser.add_argument('--val_landmarks', dest='v_land', default='', required=False, type=str,
                        help='Path to landmarks for test images.')

    # Model configuration.
    parser.add_argument('--model', choices=models_backbones.keys(), type=str, default='mobilenet',
                        help='Model type.')
    parser.add_argument('--embed_size', type=int, default=256, help='Size of the face embedding.')

    # Optimizer configuration.
    parser.add_argument('--train_batch_size', type=int, default=170, help='Train batch size.')
    parser.add_argument('--epoch_total_num', type=int, default=30, help='Number of epochs to train.')
    parser.add_argument('--lr', type=float, default=0.4, help='Learning rate.')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
    parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay.')

    # Loss configuration.
    parser.add_argument('--mining_type', choices=['focal', 'sv'], type=str, default='sv',
                        help='Hard mining method in loss.')
    parser.add_argument('--t', type=float, default=1.1,
                        help='t in support vector softmax. See https://arxiv.org/abs/1812.11317 for details')
    parser.add_argument('--gamma', type=float, default=2.,
                        help='Gamma in focal loss. See https://arxiv.org/abs/1708.02002 for details')
    parser.add_argument('--m', type=float, default=0.35, help='Margin size for AMSoftmax.')
    parser.add_argument('--s', type=float, default=30., help='Scale for AMSoftmax.')
    parser.add_argument('--margin_type', choices=['cos', 'arc'], type=str, default='cos',
                        help='Margin type for AMSoftmax loss.')

    # Other parameters.
    parser.add_argument('--devices', type=int, nargs='+', default=[0], help='CUDA devices to use.')
    parser.add_argument('--val_batch_size', type=int, default=20, help='Validation batch size.')
    parser.add_argument('--snap_folder', type=str, default='./snapshots/', help='Folder to save snapshots.')
    parser.add_argument('--snap_prefix', type=str, default='FaceReidNet', help='Prefix for snapshots.')
    parser.add_argument('--snap_to_resume', type=str, default=None, help='Snapshot to resume.')
    parser.add_argument('--weighted', action='store_true')
    parser.add_argument('-c', '--compr_config', help='Path to a file with compression parameters',
                        required=False)
    parser.add_argument('--to-onnx', type=str, metavar='PATH', default=None,
                        help='Export to ONNX model by given path')

    options = parser.parse_args()
    formatted_options = pformat(vars(options))
    log.info('Arguments:\n' + formatted_options)

    if options.compr_config:
        patch_torch_operators()

    primary_device = options.devices[0]
    with torch.cuda.device(primary_device):
        train(options)