def prune_finetune_test(iteration, model_cpy, pruning_step, test_fn, train_fn,
                        app_args, param_name, effective_train_size):
    pylogger = PythonLogger(msglogger)
    zeros_mask_dict = distiller.create_model_masks_dict(model_cpy)
    param = model_cpy.state_dict()[param_name]
    if 0 == prune_tensor(param, param_name, pruning_step, zeros_mask_dict):
        return (-1, -1, -1, zeros_mask_dict)  # Did not prune anything

    if train_fn is not None:
        # Fine-tune
        optimizer = torch.optim.SGD(model_cpy.parameters(),
                                    lr=app_args.lr,
                                    momentum=app_args.momentum,
                                    weight_decay=app_args.weight_decay)
        app_args.effective_train_size = effective_train_size
        train_fn(model=model_cpy,
                 compression_scheduler=create_scheduler(
                     model_cpy, zeros_mask_dict),
                 optimizer=optimizer,
                 epoch=iteration,
                 loggers=[pylogger])

    # Physically remove filters
    dataset = app_args.dataset
    arch = app_args.arch
    distiller.remove_filters(model_cpy,
                             zeros_mask_dict,
                             arch,
                             dataset,
                             optimizer=None)

    # Test and record the performance of the pruned model
    prec1, prec5, loss = test_fn(model=model_cpy, loggers=None)
    return (prec1, prec5, loss, zeros_mask_dict)
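A hedged usage sketch: driving prune_finetune_test over a grid of pruning levels for a single parameter (the step grid, the copy.deepcopy of the model, and the parameter name are illustrative assumptions; test_fn, train_fn and app_args follow the signature above).

import copy
import numpy as np

best = (-1, -1, -1, None)
for pruning_step in np.arange(0.1, 0.6, 0.1):
    model_cpy = copy.deepcopy(model)  # always prune a copy, never the original
    prec1, prec5, loss, masks = prune_finetune_test(
        iteration=0, model_cpy=model_cpy, pruning_step=pruning_step,
        test_fn=test_fn, train_fn=train_fn, app_args=app_args,
        param_name='module.layer1.0.conv1.weight',
        effective_train_size=0.1)
    if prec1 == -1:
        continue  # sentinel: nothing was pruned at this step
    if prec1 > best[0]:
        best = (prec1, prec5, loss, masks)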
Example #2
def main():
    global msglogger
    check_pytorch_version()
    args = parser.parse_args()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(
        os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

    # Log various details about the execution environment.  It is sometimes useful
    # to refer back to past experiment executions, and this information can help.
    apputils.log_execution_env_state(sys.argv, gitroot=module_path)
    msglogger.debug("Distiller: %s", distiller.__version__)

    start_epoch = 0
    best_top1 = 0

    if args.deterministic:
        # Experiment reproducibility is sometimes important.  Pete Warden expounded on this
        # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        # In PyTorch, support for deterministic execution is still a bit clunky.
        if args.workers > 1:
            msglogger.error(
                'ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1'
            )
            exit(1)
        # Use a well-known seed, for repeatability of experiments
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        cudnn.deterministic = True
    else:
        # This issue: https://github.com/pytorch/pytorch/issues/3659
        # implies that cudnn.benchmark should respect cudnn.deterministic, but empirically
        # results are not reproduced when benchmark is set.  So we enable benchmark only
        # when deterministic mode is disabled.
        cudnn.benchmark = True

    if args.gpus is not None:
        try:
            args.gpus = [int(s) for s in args.gpus.split(',')]
        except ValueError:
            msglogger.error(
                'ERROR: Argument --gpus must be a comma-separated list of integers only'
            )
            exit(1)
        available_gpus = torch.cuda.device_count()
        for dev_id in args.gpus:
            if dev_id >= available_gpus:
                msglogger.error(
                    'ERROR: GPU device ID {0} requested, but only {1} devices available'
                    .format(dev_id, available_gpus))
                exit(1)
        # Set default device in case the first one on the list != 0
        torch.cuda.set_device(args.gpus[0])

    # Infer the dataset from the model name
    args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'

    # Create the model
    png_summary = args.summary is not None and args.summary.startswith('png')
    is_parallel = not png_summary and args.summary != 'compute'  # For PNG summary, parallel graphs are illegible
    model = create_model(args.pretrained,
                         args.dataset,
                         args.arch,
                         parallel=is_parallel,
                         device_ids=args.gpus)

    compression_scheduler = None
    # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
    # that can be read by Google's TensorBoard.  PythonLogger writes to the Python logger.
    tflogger = TensorBoardLogger(msglogger.logdir)
    pylogger = PythonLogger(msglogger)

    # We can optionally resume from a checkpoint
    if args.resume:
        model, compression_scheduler, start_epoch = apputils.load_checkpoint(
            model, chkpt_file=args.resume)

        if 'resnet' in args.arch and 'preact' not in args.arch and 'cifar' in args.arch:
            distiller.resnet_cifar_remove_layers(model)
            #model = distiller.resnet_cifar_remove_channels(model, compression_scheduler.zeros_mask_dict)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    msglogger.info('Optimizer Type: %s', type(optimizer))
    msglogger.info('Optimizer Args: %s', optimizer.defaults)

    # This sample application can be invoked to produce various summary reports.
    if args.summary:
        which_summary = args.summary
        if which_summary.startswith('png'):
            apputils.draw_img_classifier_to_file(
                model, 'model.png', args.dataset,
                which_summary == 'png_w_params')
        else:
            distiller.model_summary(model, which_summary, args.dataset)
        exit()

    # Load the datasets: the dataset to load is inferred from the model name passed
    # in args.arch.  The default dataset is ImageNet, but if args.arch contains the
    # substring "cifar", then cifar10 is used.
    train_loader, val_loader, test_loader, _ = apputils.load_data(
        args.dataset, os.path.expanduser(args.data), args.batch_size,
        args.workers, args.validation_size, args.deterministic)
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler),
                   len(test_loader.sampler))

    activations_sparsity = None
    if args.activation_stats:
        # If your model has ReLU layers, then those layers have sparse activations.
        # ActivationSparsityCollector will collect information about this sparsity.
        # WARNING! Enabling activation sparsity collection will significantly slow down training!
        activations_sparsity = ActivationSparsityCollector(model)

    if args.sensitivity is not None:
        # This sample application can be invoked to execute Sensitivity Analysis on your
        # model.  The output is saved to CSV and PNG.
        msglogger.info("Running sensitivity tests")
        test_fnc = partial(test,
                           test_loader=test_loader,
                           criterion=criterion,
                           loggers=[pylogger],
                           print_freq=args.print_freq)
        which_params = [
            param_name for param_name, _ in model.named_parameters()
        ]
        sensitivity = distiller.perform_sensitivity_analysis(
            model,
            net_params=which_params,
            sparsities=np.arange(0.0, 0.50, 0.05)
            if args.sensitivity == 'filter' else np.arange(0.0, 0.95, 0.05),
            test_func=test_fnc,
            group=args.sensitivity)
        distiller.sensitivities_to_png(sensitivity, 'sensitivity.png')
        distiller.sensitivities_to_csv(sensitivity, 'sensitivity.csv')
        exit()

    if args.evaluate:
        # This sample application can be invoked to evaluate the accuracy of your model on
        # the test dataset.
        # You can optionally quantize the model to 8-bit integer before evaluation.
        # For example:
        # python3 compress_classifier.py --arch resnet20_cifar  ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate
        if args.quantize:
            model.cpu()
            quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
            quantizer.prepare_model()
            model.cuda()
        top1, _, _ = test(test_loader, model, criterion, [pylogger],
                          args.print_freq)
        if args.quantize:
            checkpoint_name = 'quantized'
            apputils.save_checkpoint(0,
                                     args.arch,
                                     model,
                                     optimizer=None,
                                     best_top1=top1,
                                     name='_'.join([args.name, checkpoint_name])
                                     if args.name else checkpoint_name,
                                     dir=msglogger.logdir)
        exit()

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        compression_scheduler = distiller.file_config(model, optimizer,
                                                      args.compress)
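        # A minimal sketch of what such a YAML schedule might contain (the pruner
        # instance and layer names are illustrative, not taken from this project):
        #
        #   version: 1
        #   pruners:
        #     low_pruner:
        #       class: AutomatedGradualPruner
        #       initial_sparsity: 0.05
        #       final_sparsity: 0.70
        #       weights: [module.layer1.0.conv1.weight]
        #   policies:
        #     - pruner:
        #         instance_name: low_pruner
        #       starting_epoch: 0
        #       ending_epoch: 30
        #       frequency: 2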

    for epoch in range(start_epoch, start_epoch + args.epochs):
        # This is the main training loop.
        msglogger.info('\n')
        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch)

        # Train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              compression_scheduler,
              loggers=[tflogger, pylogger],
              print_freq=args.print_freq,
              log_params_hist=args.log_params_histograms)
        distiller.log_weights_sparsity(model,
                                       epoch,
                                       loggers=[tflogger, pylogger])
        if args.activation_stats:
            distiller.log_activation_sparsity(epoch,
                                              loggers=[tflogger, pylogger],
                                              collector=activations_sparsity)

        # evaluate on validation set
        top1, top5, vloss = validate(val_loader, model, criterion, [pylogger],
                                     args.print_freq, epoch)
        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', vloss), ('Top1', top1),
                              ('Top5', top5)]))
        distiller.log_training_progress(stats,
                                        None,
                                        epoch,
                                        steps_completed=0,
                                        total_steps=1,
                                        log_freq=1,
                                        loggers=[tflogger])

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch)

        # remember best top1 and save checkpoint
        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        apputils.save_checkpoint(epoch, args.arch, model, optimizer,
                                 compression_scheduler, best_top1, is_best,
                                 args.name, msglogger.logdir)

    # Finally run results on the test set
    test(test_loader, model, criterion, [pylogger], args.print_freq)
Example #3
def main():
    global msglogger
    check_pytorch_version()
    args = parser.parse_args()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

    # Log various details about the execution environment.  It is sometimes useful
    # to refer back to past experiment executions, and this information can help.
    apputils.log_execution_env_state(sys.argv, gitroot=module_path)
    msglogger.debug("Distiller: %s", distiller.__version__)

    start_epoch = 0
    best_top1 = 0
    best_epoch = 0

    if args.deterministic:
        # Experiment reproducibility is sometimes important.  Pete Warden expounded on this
        # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        # In PyTorch, support for deterministic execution is still a bit clunky.
        if args.workers > 1:
            msglogger.error('ERROR: Setting --deterministic requires setting --workers/-j to 0 or 1')
            exit(1)
        # Use a well-known seed, for repeatability of experiments
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        cudnn.deterministic = True
    else:
        # This issue: https://github.com/pytorch/pytorch/issues/3659
        # implies that cudnn.benchmark should respect cudnn.deterministic, but empirically
        # results are not reproduced when benchmark is set.  So we enable benchmark only
        # when deterministic mode is disabled.
        cudnn.benchmark = True

    if args.gpus is not None:
        try:
            args.gpus = [int(s) for s in args.gpus.split(',')]
        except ValueError:
            msglogger.error('ERROR: Argument --gpus must be a comma-separated list of integers only')
            exit(1)
        available_gpus = torch.cuda.device_count()
        for dev_id in args.gpus:
            if dev_id >= available_gpus:
                msglogger.error('ERROR: GPU device ID {0} requested, but only {1} devices available'
                                .format(dev_id, available_gpus))
                exit(1)
        # Set default device in case the first one on the list != 0
        torch.cuda.set_device(args.gpus[0])

    # Infer the dataset from the model name
    args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
    args.num_classes = 10 if args.dataset == 'cifar10' else 1000

    if args.earlyexit_thresholds:
        # N confidence thresholds define N early exits, plus the network's final exit
        args.num_exits = len(args.earlyexit_thresholds) + 1
        args.loss_exits = [0] * args.num_exits
        args.losses_exits = []
        args.exiterrors = []

    # Create the model
    model = create_model(args.pretrained, args.dataset, args.arch, device_ids=args.gpus)
    compression_scheduler = None
    # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
    # that can be read by Google's TensorBoard.  PythonLogger writes to the Python logger.
    tflogger = TensorBoardLogger(msglogger.logdir)
    pylogger = PythonLogger(msglogger)

    # capture thresholds for early-exit training
    if args.earlyexit_thresholds:
        msglogger.info('=> using early-exit threshold values of %s', args.earlyexit_thresholds)

    # We can optionally resume from a checkpoint
    if args.resume:
        model, compression_scheduler, start_epoch = apputils.load_checkpoint(
            model, chkpt_file=args.resume)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    msglogger.info('Optimizer Type: %s', type(optimizer))
    msglogger.info('Optimizer Args: %s', optimizer.defaults)

    if args.ADC:
        return automated_deep_compression(model, criterion, pylogger, args)

    # This sample application can be invoked to produce various summary reports.
    if args.summary:
        return summarize_model(model, args.dataset, which_summary=args.summary)

    # Load the datasets: the dataset to load is inferred from the model name passed
    # in args.arch.  The default dataset is ImageNet, but if args.arch contains the
    # substring "cifar", then cifar10 is used.
    train_loader, val_loader, test_loader, _ = apputils.load_data(
        args.dataset, os.path.expanduser(args.data), args.batch_size,
        args.workers, args.validation_size, args.deterministic)
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))

    activations_sparsity = None
    if args.activation_stats:
        # If your model has ReLU layers, then those layers have sparse activations.
        # ActivationSparsityCollector will collect information about this sparsity.
        # WARNING! Enabling activation sparsity collection will significantly slow down training!
        activations_sparsity = ActivationSparsityCollector(model)

    if args.sensitivity is not None:
        return sensitivity_analysis(model, criterion, test_loader, pylogger, args)

    if args.evaluate:
        return evaluate_model(model, criterion, test_loader, pylogger, args)

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        compression_scheduler = distiller.file_config(model, optimizer, args.compress)
        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
        model.cuda()
    else:
        compression_scheduler = distiller.CompressionScheduler(model)

    args.kd_policy = None
    if args.kd_teacher:
        teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)
        if args.kd_resume:
            teacher, _, _ = apputils.load_checkpoint(teacher, chkpt_file=args.kd_resume)
        dlw = distiller.DistillationLossWeights(args.kd_distill_wt, args.kd_student_wt, args.kd_teacher_wt)
        args.kd_policy = distiller.KnowledgeDistillationPolicy(model, teacher, args.kd_temp, dlw)
        compression_scheduler.add_policy(args.kd_policy, starting_epoch=args.kd_start_epoch, ending_epoch=args.epochs,
                                         frequency=1)

        msglogger.info('\nStudent-Teacher knowledge distillation enabled:')
        msglogger.info('\tTeacher Model: %s', args.kd_teacher)
        msglogger.info('\tTemperature: %s', args.kd_temp)
        msglogger.info('\tLoss Weights (distillation | student | teacher): %s',
                       ' | '.join(['{:.2f}'.format(val) for val in dlw]))
        msglogger.info('\tStarting from Epoch: %s', args.kd_start_epoch)
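
        # Conceptually (a sketch, not necessarily this library's exact formulation),
        # the policy blends per-minibatch losses using the weights captured above:
        #   overall_loss = kd_student_wt * CE(student_logits, labels)
        #                + kd_distill_wt * soft_target_loss(student_logits / T,
        #                                                   teacher_logits / T)
        # with T = args.kd_temp; the teacher only provides targets and is not updated.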

    for epoch in range(start_epoch, start_epoch + args.epochs):
        # This is the main training loop.
        msglogger.info('\n')
        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
              loggers=[tflogger, pylogger], args=args)
        distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
        if args.activation_stats:
            distiller.log_activation_sparsity(epoch, loggers=[tflogger, pylogger],
                                              collector=activations_sparsity)

        # evaluate on validation set
        top1, top5, vloss = validate(val_loader, model, criterion, [pylogger], args, epoch)
        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', vloss),
                              ('Top1', top1),
                              ('Top5', top5)]))
        distiller.log_training_progress(stats, None, epoch, steps_completed=0, total_steps=1, log_freq=1,
                                        loggers=[tflogger])

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch, optimizer)

        # remember best top1 and save checkpoint
        is_best = top1 > best_top1
        if is_best:
            best_epoch = epoch
            best_top1 = top1
        msglogger.info('==> Best Top1: %.3f   On Epoch: %d\n', best_top1, best_epoch)
        apputils.save_checkpoint(epoch, args.arch, model, optimizer, compression_scheduler, best_top1, is_best,
                                 args.name, msglogger.logdir)

    # Finally run results on the test set
    test(test_loader, model, criterion, [pylogger], args=args)
Example #4
File: main.py  Project: zxshinxz/distiller
def export_onnx(path, batch_size, seq_len):
    msglogger.info('The model is also exported in ONNX format at {}'.format(
        os.path.realpath(path)))
    model.eval()
    dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(
        -1, batch_size).to(device)
    hidden = model.init_hidden(batch_size)
    torch.onnx.export(model, (dummy_input, hidden), path)
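
# A quick sanity check of the exported graph, assuming the `onnx` package is
# installed ('model.onnx' stands in for whatever path was passed to export_onnx):
#
#   import onnx
#   onnx_model = onnx.load('model.onnx')
#   onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
#   print(onnx.helper.printable_graph(onnx_model.graph))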


# Distiller loggers
msglogger = apputils.config_pylogger('logging.conf', None)
tflogger = TensorBoardLogger(msglogger.logdir)
tflogger.log_gradients = True
pylogger = PythonLogger(msglogger)

if args.summary:
    which_summary = args.summary
    if which_summary == 'png':
        draw_lang_model_to_file(model, 'rnn.png', 'wikitext2')
    elif which_summary == 'percentile':
        percentile = 0.9
        for name, param in model.state_dict().items():
            if param.dim() < 2:
                # Skip biases
                continue
            bottomk, _ = torch.topk(param.abs().view(-1),
                                    int(percentile * param.numel()),
                                    largest=False,
                                    sorted=True)
            # The largest of the bottom-k magnitudes is the percentile threshold
            # (the report line below reconstructs the truncated original snippet)
            msglogger.info('%s: %.1f-percentile magnitude threshold = %f',
                           name, percentile, bottomk[-1])
Example #5
def train(c, net, compression_scheduler=None):
    import distiller.apputils as apputils
    from distiller.data_loggers import TensorBoardLogger, PythonLogger
    msglogger = apputils.config_pylogger('logging.conf', None)
    tflogger = TensorBoardLogger(msglogger.logdir)
    tflogger.log_gradients = True
    pylogger = PythonLogger(msglogger)
    c.setdefault(hebbian=False)

    emb_params = count_params(net.embed) + count_params(net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    if c.get("steps_per_epoch"):
        steps_per_epoch = c.steps_per_epoch
    else:
        steps_per_epoch = len(data_tr.tokens) // data_tr.bs // c.train_chunk
    print("#### steps per epoch %d ####" % steps_per_epoch)

    if c.hebbian:
        # one counter per vocabulary cluster: the cutoffs partition [0, n_vocab)
        counters = [torch.ones(end - start, dtype=torch.long, device=c.device)
                    for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])]
        temp_counters = [torch.zeros_like(x) for x in counters]

    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        best_val_loss = s.results['val_loss'].dropna().min()  # lower loss is better
    try:
        while step < s.step_max:
            batch = step % steps_per_epoch
            epoch = step // steps_per_epoch
            if step % steps_per_epoch == 0:
                c.log("====> batch=%d, epoch=%d, step=%d" % (batch, epoch, step))
                if compression_scheduler:
                    compression_scheduler.on_epoch_begin(epoch)

            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)

            step_lr(step)

            x = to_torch(next(iter_tr), c.device).t()

            t_s = time()
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']

            if compression_scheduler:
                _ = compression_scheduler.before_backward_pass(
                    epoch, minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch,
                    loss=loss, return_loss_components=False)

            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)

            time_model = np.round(time() - t_s, 5)

            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            if c.use_cache:
                step_result['theta'] = preds['theta']
                step_result['lambda'] = preds['lambda'].item()

            if compression_scheduler:
                compression_scheduler.on_minibatch_end(epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)

            if step % steps_per_epoch == 0:
                if compression_scheduler:
                    compression_scheduler.on_epoch_end(epoch)

            s.step = step = step + 1
            if step % c.step_eval == 0:
                distiller.log_weights_sparsity(net, epoch, loggers=[tflogger, pylogger])
                t, total = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
                c.log("total sparsity: %.3lf" % total)

                step_result = step_result.append(
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
                )
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
Example #6
def train(args):

    SRC = data.Field(tokenize=tokenize_de,
                     pad_token=BLANK_WORD,
                     lower=args.lower)
    TGT = data.Field(tokenize=tokenize_en,
                     init_token=BOS_WORD,
                     eos_token=EOS_WORD,
                     pad_token=BLANK_WORD,
                     lower=args.lower)

    # Load IWSLT Data ---> German to English Translation
    if args.dataset == 'IWSLT':
        train, val, test = datasets.IWSLT.splits(
            exts=('.de', '.en'),
            fields=(SRC, TGT),
            filter_pred=lambda x: len(vars(x)['src']) <= args.max_length and
            len(vars(x)['trg']) <= args.max_length)
    else:
        train, val, test = datasets.Multi30k.splits(
            exts=('.de', '.en'),
            fields=(SRC, TGT),
            filter_pred=lambda x: len(vars(x)['src']) <= args.max_length and
            len(vars(x)['trg']) <= args.max_length)

    # Frequency of words in the vocabulary
    SRC.build_vocab(train.src, min_freq=args.min_freq)
    TGT.build_vocab(train.trg, min_freq=args.min_freq)

    print("Size of source vocabulary:", len(SRC.vocab))
    print("Size of target vocabulary:", len(TGT.vocab))

    pad_idx = TGT.vocab.stoi[BLANK_WORD]
    model = make_model(len(SRC.vocab),
                       len(TGT.vocab),
                       n=args.num_blocks,
                       d_model=args.hidden_dim,
                       d_ff=args.ff_dim,
                       h=args.num_heads,
                       dropout=args.dropout)
    print("Model made with n:", args.num_blocks, "hidden_dim:",
          args.hidden_dim, "feed forward dim:", args.ff_dim, "heads:",
          args.num_heads, "dropout:", args.dropout)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("Number of parameters: ", params)

    if args.load_model:
        print("Loading model from [%s]" % args.load_model)
        model.load_state_dict(torch.load(args.load_model))

    # UNCOMMENT WHEN RUNNING ON RESEARCH MACHINES - run on GPU
    # model.cuda()

    # Used by original authors, hurts perplexity but improves BLEU score
    criterion = LabelSmoothing(size=len(TGT.vocab),
                               padding_idx=pad_idx,
                               smoothing=0.1)

    # UNCOMMENT WHEN RUNNING ON RESEARCH MACHINES - run on GPU
    # criterion.cuda()

    train_iter = MyIterator(train,
                            batch_size=args.batch_size,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=True)
    valid_iter = MyIterator(val,
                            batch_size=args.batch_size,
                            device=0,
                            repeat=False,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn,
                            train=False,
                            sort=False)
    model_par = nn.DataParallel(model, device_ids=devices)

    # model_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,
    #                     torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    # Use standard optimizer -- As used in the paper
    model_opt = get_std_opt(model)

    # PRUNING CODE
    if args.summary:
        df = distiller.weights_sparsity_tbl_summary(model, False)
        print(df)
        exit(0)

    msglogger = apputils.config_pylogger('logging.conf', None)
    tflogger = TensorBoardLogger(msglogger.logdir)
    tflogger.log_gradients = True
    pylogger = PythonLogger(msglogger)

    # Bind compression_scheduler unconditionally: the epoch loop below checks it,
    # and it would otherwise be undefined when --compress is not given
    compression_scheduler = None
    if args.compress:
        compression_scheduler = distiller.config.file_config(
            model_par.module, None, args.compress)

    print(model_par.module)

    best_bleu = 0
    best_epoch = 0

    steps_per_epoch = math.ceil(len(train_iter.data()) / 60)

    for epoch in range(args.epoch):
        print("=" * 80)
        print("Epoch ", epoch + 1)
        print("=" * 80)
        print("Training...")
        model_par.train()

        if compression_scheduler:
            compression_scheduler.on_epoch_begin(epoch)

        # IF PRUNING
        run_epoch((rebatch(pad_idx, b) for b in train_iter),
                  model_par,
                  MultiGPULossCompute(model.generator,
                                      criterion,
                                      devices=devices,
                                      opt=model_opt),
                  args,
                  epoch,
                  steps_per_epoch,
                  compression_scheduler,
                  SRC,
                  TGT,
                  valid_iter,
                  is_valid=False)

        # run_epoch((rebatch(pad_idx, b) for b in train_iter), model_par,
        #           MultiGPULossCompute(model.generator, criterion, devices=devices, opt=model_opt), args,
        #           SRC, TGT, valid_iter, is_valid=False)

        print("Validation...")
        model_par.eval()

        # IF PRUNING
        loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter),
                         model_par,
                         MultiGPULossCompute(model.generator,
                                             criterion,
                                             devices=devices,
                                             opt=None),
                         args,
                         epoch,
                         steps_per_epoch,
                         compression_scheduler,
                         SRC,
                         TGT,
                         valid_iter,
                         is_valid=True)

        # loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter), model_par,
        #                  MultiGPULossCompute(model.generator, criterion, devices=devices, opt=None), args,
        #                  SRC, TGT, valid_iter, is_valid=True)

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch)

        print('Validation loss:', loss)
        print('Validation perplexity: ', np.exp(loss))
        bleu_score = run_validation_bleu_score(model, SRC, TGT, valid_iter)

        if best_bleu < bleu_score:
            best_bleu = bleu_score
            model_file = args.save_to + args.exp_name + 'validation.bin'
            print('Saving model without optimizer [%s]' % model_file)
            torch.save(model_par.module.state_dict(), model_file)
            best_epoch = epoch

        model_file = args.save_to + args.exp_name + 'latest.bin'
        print('Saving latest model without optimizer [%s]' % model_file)
        torch.save(model_par.module.state_dict(), model_file)

    print('The best epoch was:', best_epoch)
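
A hedged sketch of restoring the best checkpoint saved above for inference (the make_model arguments mirror the training call; the path is illustrative):

model = make_model(len(SRC.vocab), len(TGT.vocab),
                   n=args.num_blocks, d_model=args.hidden_dim,
                   d_ff=args.ff_dim, h=args.num_heads,
                   dropout=args.dropout)
model.load_state_dict(torch.load(args.save_to + args.exp_name + 'validation.bin'))
model.eval()  # disable dropout before decoding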