def main():
    """Train/validate a quantized CNN using the NEMO PACT flow.

    Parses CLI arguments, builds the requested model, optionally converts it
    to a NEMO FakeQuantized representation (with calibration or checkpoint
    resume), then runs the train/validate loop, logging results and saving
    checkpoints.

    Relies on module-level collaborators defined elsewhere in this file:
    ``parser``, ``models``, ``nemo``, ``train``, ``validate``,
    ``get_dataset``, ``get_transform``, ``get_num_classes``,
    ``memory_driven_quant``, ``reset_grad_flow``, ``setup_logging``,
    ``ResultsLog``, ``__global_ave_grads``, ``__global_max_grads``.

    Side effects: writes logs/results under ``args.results_dir`` and saves
    checkpoints via ``nemo.utils.save_checkpoint``.
    """
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    weight_bits = int(args.weight_bits)
    activ_bits = int(args.activ_bits)

    # BUGFIX: compare strings with ==, not `is` -- identity only works by
    # accident of CPython string interning (and emits SyntaxWarning on 3.8+).
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    writer = SummaryWriter()

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        print('Selected GPUs: ', args.gpus)
        # torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # BUGFIX: build model_config *before* model creation -- the 'mobilenet'
    # branch instantiates the model from it, but the original code defined
    # model_config only afterwards, raising NameError on that path.
    nClasses = get_num_classes(args.dataset)
    model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': nClasses, \
                    'width_mult': float(args.mobilenet_width), 'input_dim': float(args.mobilenet_input) }
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # create model
    logging.info("creating model %s", args.model)
    if args.model == 'mobilenet':
        model = models.__dict__[args.model]
        model = model(**model_config)
    elif args.model == 'mobilenetv2':
        model = torch.hub.load('pytorch/vision:v0.6.0',
                               'mobilenet_v2',
                               pretrained=True)
    elif args.model == 'resnet18':
        model = torch.hub.load('pytorch/vision:v0.6.0',
                               'resnet18',
                               pretrained=True)
    else:  # fallback: mobilenet_v3 with local pretrained weights
        model = models.mobilenetv3_large(
            width_mult=float(args.mobilenet_width))
        model.load_state_dict(
            torch.load(
                "models/mobilenet_v3/mobilenetv3-large-0.75-9632d2a8.pth"))

    logging.info("created model with configuration: %s", model_config)
    print(model)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model, 'regime', {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'momentum': args.momentum,
                'weight_decay': args.weight_decay
            }
        })
    print(transform)
    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # quick validation loader: 1000 random samples drawn with replacement
    fast_val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        sampler=torch.utils.data.RandomSampler(val_data,
                                               replacement=True,
                                               num_samples=1000))

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # BUGFIX: the fast *train* loader must sample the train set -- the
    # original constructed its RandomSampler over val_data.
    fast_train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        sampler=torch.utils.data.RandomSampler(train_data,
                                               replacement=True,
                                               num_samples=100000))

    # NOTE(review): the original built a per-parameter `params` list here
    # (separate weight decay for 'alpha'/'beta' clipping parameters) but
    # never passed it to any optimizer; that dead code was removed. The
    # actual optimizer is created below from
    # model.get_nonclip_parameters() / model.get_clip_parameters().

    mixed_prec_dict = None
    if args.mixed_prec_dict is not None:
        # mixed-precision assignment supplied externally as JSON
        mixed_prec_dict = nemo.utils.precision_dict_from_json(
            args.mixed_prec_dict)
        print("Load mixed precision dict from outside")
    elif args.mem_constraint != '':
        # derive a mixed-precision assignment from memory constraints
        mem_contraints = json.loads(args.mem_constraint)
        print('This is the memory constraint:', mem_contraints)
        if mem_contraints is not None:
            x_test = torch.Tensor(1, 3, 224, 224)
            mixed_prec_dict = memory_driven_quant(model,
                                                  x_test,
                                                  mem_contraints[0],
                                                  mem_contraints[1],
                                                  args.mixed_prec_quant,
                                                  use_sawb=args.use_sawb)

    # multi gpus
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()

    if args.resume is None:
        # baseline accuracy of the full-precision model
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  0, None)
        print("[NEMO] Full-precision model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))

    if args.quantize:

        # transform the model in a NEMO FakeQuantized representation
        model = nemo.transform.quantize_pact(model,
                                             dummy_input=torch.randn(
                                                 (1, 3, 224, 224)).to('cuda'))

        if args.resume is not None:
            checkpoint_file = args.resume
            if os.path.isfile(checkpoint_file):
                logging.info("loading checkpoint '%s'", args.resume)
                checkpoint_loaded = torch.load(checkpoint_file)
                checkpoint = checkpoint_loaded['state_dict']
                model.load_state_dict(checkpoint, strict=True)
                prec_dict = checkpoint_loaded.get('precision')
            else:
                logging.error("no checkpoint found at '%s'", args.resume)
                import sys
                sys.exit(1)

        if args.resume is None:
            print("[NEMO] Model calibration")
            model.change_precision(bits=20)
            model.reset_alpha_weights()

            if args.initial_folding:
                model.fold_bn()
                # use DFQ for weight equalization
                if args.initial_equalization:
                    model.equalize_weights_dfq()
            elif args.initial_equalization:
                model.equalize_weights_lsq(verbose=True)
                model.reset_alpha_weights()

            # calibrate activation clipping after equalization: collect
            # activation statistics over a validation pass, then reset alphas
            with model.statistics_act():
                val_loss, val_prec1, val_prec5 = validate(
                    val_loader, model, criterion, 0, None)
            model.reset_alpha_act()

            val_loss, val_prec1, val_prec5 = validate(val_loader, model,
                                                      criterion, 0, None)

            print("[NEMO] 20-bit calibrated model: top-1=%.2f top-5=%.2f" %
                  (val_prec1, val_prec5))
            nemo.utils.save_checkpoint(model,
                                       None,
                                       0,
                                       acc=val_prec1,
                                       checkpoint_name='resnet18_calibrated',
                                       checkpoint_suffix=args.suffix)

            model.change_precision(bits=activ_bits)
            model.change_precision(bits=weight_bits, scale_activations=False)

            # init weight clipping parameters to their reset value and disable their gradient
            model.reset_alpha_weights()
            if args.use_sawb:
                model.disable_grad_sawb()
                model.weight_clip_sawb()

            # Hand-tuned mixed-precision assignment.
            # NOTE(review): the layer names below assume a ResNet-18
            # topology -- this branch would KeyError for the mobilenet
            # variants; confirm it is only reached with --model resnet18.
            mixed_prec_dict_all = model.export_precision()
            mixed_prec_dict_all['relu']['x_bits'] = 2
            mixed_prec_dict_all['layer1.0.relu']['x_bits'] = 4
            mixed_prec_dict_all['layer3.1.conv1']['W_bits'] = 4
            mixed_prec_dict_all['layer3.1.conv2']['W_bits'] = 4
            mixed_prec_dict_all['layer4.0.conv1']['W_bits'] = 2
            mixed_prec_dict_all['layer4.0.conv2']['W_bits'] = 2
            mixed_prec_dict_all['layer4.1.conv1']['W_bits'] = 2
            mixed_prec_dict_all['layer4.1.conv2']['W_bits'] = 2
            model.change_precision(bits=1, min_prec_dict=mixed_prec_dict_all)

        else:
            print("[NEMO] Not calibrating model, as it is pretrained")
            model.change_precision(bits=1, min_prec_dict=prec_dict)

    # clipping parameters get a stronger weight decay than regular weights
    optimizer = torch.optim.Adam([
        {
            'params': model.get_nonclip_parameters(),
            'lr': args.lr,
            'weight_decay': 1e-5
        },
        {
            'params': model.get_clip_parameters(),
            'lr': args.lr,
            'weight_decay': 0.001
        },
    ])

    reset_grad_flow(model, __global_ave_grads, __global_max_grads)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch; BN is absorbed on the first epoch and
        # frozen on all subsequent ones
        train_loss, train_prec1, train_prec5 = train(
            train_loader,
            model,
            criterion,
            epoch,
            optimizer,
            freeze_bn=epoch > 0,
            absorb_bn=epoch == 0,
            writer=writer)
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        writer.add_scalar('Loss/val', val_loss, epoch * len(train_loader))
        writer.add_scalar('Accuracy/val', val_prec1, epoch * len(train_loader))

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)

        # save model (always when --save-check, and whenever a new best)
        checkpoint_stem = 'resnet18%s' % (
            "_mixed" if mixed_prec_dict is not None else "")
        if args.save_check:
            nemo.utils.save_checkpoint(
                model,
                optimizer,
                0,
                acc=val_prec1,
                checkpoint_name='%s_checkpoint' % checkpoint_stem,
                checkpoint_suffix=args.suffix)

        if is_best:
            nemo.utils.save_checkpoint(
                model,
                optimizer,
                0,
                acc=val_prec1,
                checkpoint_name='%s_best' % checkpoint_stem,
                checkpoint_suffix=args.suffix)

        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \t'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.save()
# --- Example #2 (scraped example separator; original residue: "예제 #2" / "0") ---
def main():
    """Train/validate a quantized CNN using the `quantization.QuantOp` flow.

    Parses CLI arguments, builds the model from `models`, optionally attaches
    a quantizer (manual, JSON-configured, or memory-constraint-driven), then
    either evaluates once (--evaluate) or runs the train/validate loop,
    logging results and saving checkpoints.

    Relies on module-level collaborators defined elsewhere in this file:
    ``parser``, ``models``, ``quantization``, ``train``, ``validate``,
    ``get_dataset``, ``get_transform``, ``get_num_classes``,
    ``memory_driven_quant``, ``adjust_optimizer``, ``save_checkpoint``,
    ``setup_logging``, ``ResultsLog``.
    """
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    weight_bits = int(args.weight_bits)
    activ_bits = int(args.activ_bits)

    # BUGFIX: compare strings with ==, not `is` -- identity only works by
    # accident of CPython string interning (and emits SyntaxWarning on 3.8+).
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        print('Selected GPUs: ', args.gpus)
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    nClasses = get_num_classes(args.dataset)
    model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': nClasses, \
                    'type_quant': args.type_quant, 'weight_bits': weight_bits, 'activ_bits': activ_bits,\
                    'activ_type': args.activ_type, 'width_mult': float(args.mobilenet_width), 'input_dim': float(args.mobilenet_input) }

    if args.model_config != '':
        # allow CLI overrides of individual config keys
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    print(model)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model, 'regime', {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'momentum': args.momentum,
                'weight_decay': args.weight_decay
            }
        })
    print(transform)
    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.quantizer:
        # separate small-batch loader for the integer deployment model
        val_quant_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=32,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # define optimizer: clipping parameters get their own weight decay
    params_dict = dict(model.named_parameters())
    params = []
    for key, value in params_dict.items():
        if 'clip_val' in key:
            params += [{'params': value, 'weight_decay': 1e-4}]
        else:
            params += [{'params': value}]
    optimizer = torch.optim.SGD(params, lr=0.1)
    logging.info('training regime: %s', regime)

    # define quantizer
    # BUGFIX: add_config must exist even when no quantizer is used -- it is
    # stored in the checkpoint below, and the original raised NameError when
    # --save-check was used without a quantizer.
    add_config = []
    if args.quantizer:
        if args.mem_constraint != '':
            # derive per-layer quantization config from memory constraints
            mem_contraints = json.loads(args.mem_constraint)
            print('This is the memory constraint:', mem_contraints)
            if mem_contraints is not None:
                x_test = torch.Tensor(1, 3, args.mobilenet_input,
                                      args.mobilenet_input)
                add_config = memory_driven_quant(model, x_test,
                                                 mem_contraints[0],
                                                 mem_contraints[1],
                                                 args.mixed_prec_quant)
                if add_config == -1:
                    print('The quantization process failed!')
        elif args.quant_add_config != '':
            # explicit per-layer config supplied as JSON on the CLI
            add_config = json.loads(args.quant_add_config)

        quantizer = quantization.QuantOp(model, args.type_quant, weight_bits, \
            batch_fold_type=args.batch_fold_type, batch_fold_delay=args.batch_fold_delay, act_bits=activ_bits, \
            add_config = add_config )
        quantizer.deployment_model.type(args.type)
        quantizer.add_params_to_optimizer(optimizer)

    else:
        quantizer = None

    # multi gpus
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.type(args.type)

    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint_loaded = torch.load(checkpoint_file)
            checkpoint = checkpoint_loaded['state_dict']
            model.load_state_dict(checkpoint, strict=False)
            print('Model pretrained')
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    if args.quantizer:
        quantizer.init_parameters()

    if args.evaluate:
        # evaluate on validation set only, then exit

        if args.quantizer:
            # evaluate deployment model on validation set
            quantizer.generate_deployment_model()
            val_quant_loss, val_quant_prec1, val_quant_prec5 = validate(
                val_quant_loader, quantizer.deployment_model, criterion, 0,
                'deployment')
        else:
            val_quant_loss, val_quant_prec1, val_quant_prec5 = 0, 0, 0

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  0, quantizer)

        logging.info('\n This is the results from evaluation only: '
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \t'
                     'Validation Quant Prec@1 {val_quant_prec1:.3f} \t'
                     'Validation Quant Prec@5 {val_quant_prec5:.3f} \n'.format(
                         val_prec1=val_prec1,
                         val_prec5=val_prec5,
                         val_quant_prec1=val_quant_prec1,
                         val_quant_prec5=val_quant_prec5))
        exit(0)

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer, quantizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch, quantizer)

        if args.quantizer:
            # evaluate deployment model on validation set
            quantizer.generate_deployment_model()
            val_quant_loss, val_quant_prec1, val_quant_prec5 = validate(
                val_quant_loader, quantizer.deployment_model, criterion, epoch,
                'deployment')
        else:
            val_quant_loss, val_quant_prec1, val_quant_prec5 = 0, 0, 0

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)

        # save model
        if args.save_check:

            print('Saving Model!! Accuracy : ', best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'config': model_config,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'regime': regime,
                    'quantizer': quantizer,
                    'add_config': add_config,
                    'fold_type': args.batch_fold_type
                },
                is_best,
                path=save_path)

        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \t'
                     'Validation Quant Prec@1 {val_quant_prec1:.3f} \t'
                     'Validation Quant Prec@5 {val_quant_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5,
                         val_quant_prec1=val_quant_prec1,
                         val_quant_prec5=val_quant_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5,
                    val_quant_error1=100 - val_quant_prec1,
                    val_quant_error5=100 - val_quant_prec5)
        results.save()
def main():
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    weight_bits = int(args.weight_bits)
    activ_bits = int(args.activ_bits)

    if args.save is '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        print('Selected GPUs: ', args.gpus)
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    if args.model == 'mobilenet':
        model = models.__dict__[args.model]
    elif args.model == 'mobilenetv2':
        model = torch.hub.load('pytorch/vision:v0.6.0',
                               'mobilenet_v2',
                               pretrained=True)
    else:  #if args.model == 'mobilenet_v3':
        model = models.mobilenetv3_large(
            width_mult=float(args.mobilenet_width))
        model.load_state_dict(
            torch.load(
                "models/mobilenet_v3/mobilenetv3-large-0.75-9632d2a8.pth"))
    nClasses = get_num_classes(args.dataset)
    model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': nClasses, \
                    'width_mult': float(args.mobilenet_width), 'input_dim': float(args.mobilenet_input) }

    if args.model_config is not '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    print(model)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model, 'regime', {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'momentum': args.momentum,
                'weight_decay': args.weight_decay
            }
        })
    print(transform)
    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    #define optimizer
    params_dict = dict(model.named_parameters())
    params = []
    for key, value in params_dict.items():
        if 'alpha' in key or 'beta' in key:
            params += [{'params': value, 'weight_decay': 1e-4}]
        else:
            params += [{'params': value, 'weight_decay': 1e-5}]

    mixed_prec_dict = None
    if args.mixed_prec_dict is not None:
        mixed_prec_dict = nemo.utils.precision_dict_from_json(
            args.mixed_prec_dict)
        print("Load mixed precision dict from outside")
    elif args.mem_constraint is not '':
        mem_contraints = json.loads(args.mem_constraint)
        print('This is the memory constraint:', mem_contraints)
        if mem_contraints is not None:
            x_test = torch.Tensor(1, 3, args.mobilenet_input,
                                  args.mobilenet_input)
            mixed_prec_dict = memory_driven_quant(model, x_test,
                                                  mem_contraints[0],
                                                  mem_contraints[1],
                                                  args.mixed_prec_quant)

    #multi gpus
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.type(args.type)

    mobilenet_width = float(args.mobilenet_width)
    mobilenet_width_s = args.mobilenet_width
    mobilenet_input = int(args.mobilenet_input)

    if args.resume is None:
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  0, None)
        print("[NEMO] Full-precision model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))

    if args.quantize:

        # transform the model in a NEMO FakeQuantized representation
        model = nemo.transform.quantize_pact(model,
                                             dummy_input=torch.randn(
                                                 (1, 3, mobilenet_input,
                                                  mobilenet_input)).to('cuda'))
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=1e-5)

        if args.resume is not None:
            checkpoint_file = args.resume
            if os.path.isfile(checkpoint_file):
                logging.info("loading checkpoint '%s'", args.resume)
                checkpoint_loaded = torch.load(checkpoint_file)
                checkpoint = checkpoint_loaded['state_dict']
                model.load_state_dict(checkpoint, strict=True)
                prec_dict = checkpoint_loaded.get('precision')
            else:
                logging.error("no checkpoint found at '%s'", args.resume)
                import sys
                sys.exit(1)

        if args.resume is None:
            print("[NEMO] Model calibration")
            # [NEMO] Calibration path: start from a high-precision (20-bit)
            # fake-quantized model so activation clipping parameters (alpha)
            # can be estimated from statistics before lowering precision.
            model.change_precision(bits=20)
            model.reset_alpha_weights()

            if args.initial_folding:
                # Fold batch-norm into the preceding convolutions first.
                model.fold_bn()
                # use DFQ for weight equalization
                if args.initial_equalization:
                    model.equalize_weights_dfq()
            elif args.initial_equalization:
                # Equalization without BN folding (LSQ-style).
                model.equalize_weights_lsq(verbose=True)
                model.reset_alpha_weights()
#                model.reset_alpha_weights(use_method='dyn_range', dyn_range_cutoff=0.05, verbose=True)

            # Calibrate after equalization: run a validation pass while
            # collecting activation statistics so NEMO can set clipping ranges.
            with model.statistics_act():
                val_loss, val_prec1, val_prec5 = validate(
                    val_loader, model, criterion, 0, None)

            # # use this in place of the usual calibration, because PACT_Act's descend from ReLU6 and
            # # the trained weights already assume the presence of a clipping effect
            # # this should be integrated in NEMO by saving the "origin" of the PACT_Act!
            # for i in range(0,27):
            #     model.model[i][3].alpha.data[:] = min(model.model[i][3].alpha.item(), model.model[i][3].max)

            # Re-validate with the calibrated clipping ranges in place.
            val_loss, val_prec1, val_prec5 = validate(val_loader, model,
                                                      criterion, 0, None)

            print("[NEMO] 20-bit calibrated model: top-1=%.2f top-5=%.2f" %
                  (val_prec1, val_prec5))
            # Persist the calibrated (still 20-bit) model before lowering
            # precision to the requested weight/activation bit-widths.
            nemo.utils.save_checkpoint(
                model,
                optimizer,
                0,
                acc=val_prec1,
                checkpoint_name='mobilenet_%s_%d_calibrated' %
                (mobilenet_width_s, mobilenet_input),
                checkpoint_suffix=args.suffix)

            # Lower activations first, then weights; scale_activations=False
            # keeps the just-scaled activations untouched on the second call.
            model.change_precision(bits=activ_bits)
            model.change_precision(bits=weight_bits, scale_activations=False)
            # NOTE(review): interactive debug shell left in the calibration
            # path -- execution blocks here until the IPython session exits.
            import IPython
            IPython.embed()

        else:
            print("[NEMO] Not calibrating model, as it is pretrained")
            # bits=1 with min_prec_dict: per-layer precisions presumably come
            # from prec_dict (defined above, outside this view) -- confirm.
            model.change_precision(bits=1, min_prec_dict=prec_dict)

            ### val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion, 0, None)
            ### print("[NEMO] pretrained model: top-1=%.2f top-5=%.2f" % (val_prec1, val_prec5))

        # Optional mixed-precision override: export the model's current
        # per-layer precisions, overwrite the entries listed in
        # mixed_prec_dict, and re-apply the merged dictionary.
        if mixed_prec_dict is not None:
            mixed_prec_dict_all = model.export_precision()
            for k in mixed_prec_dict.keys():
                mixed_prec_dict_all[k] = mixed_prec_dict[k]
            model.change_precision(bits=1, min_prec_dict=mixed_prec_dict_all)

            # freeze and quantize BN parameters
            # nemo.transform.bn_quantizer(model, precision=nemo.precision.Precision(bits=20))
            # model.freeze_bn()
            # model.fold_bn()
            # model.equalize_weights_dfq(verbose=True)
            val_loss, val_prec1, val_prec5 = validate(val_loader, model,
                                                      criterion, 0, None)


#            print("[NEMO] Rounding weights")
#            model.round_weights()

    # --pure_export: walk the model through NEMO's deployment stages
    # (fake-quantized FQ -> quantized-deployment QD -> integer-deployment ID),
    # spot-check accuracy after each stage on a shortened validation run,
    # export the ONNX graph, and exit.
    if args.pure_export:
        model.freeze_bn(reset_stats=True, disable_grad=True)
        val_loss, val_prec1, val_prec5 = validate(val_loader,
                                                  model,
                                                  criterion,
                                                  0,
                                                  None,
                                                  shorten=10)
        print("[NEMO] FQ model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))
        # Shift the input into the unsigned range: add a +1.0 bias at the
        # first layers and remove it downstream; the bias is rounded down to
        # a multiple of the input quantum eps_in = 2/255.
        input_bias_dict = {'model.0.0': +1.0, 'model.0.1': +1.0}
        remove_bias_dict = {'model.0.1': 'model.0.2'}
        input_bias = math.floor(1.0 / (2. / 255)) * (2. / 255)
        model.qd_stage(eps_in=2. / 255,
                       add_input_bias_dict=input_bias_dict,
                       remove_bias_dict=remove_bias_dict,
                       int_accurate=True)
        # Fix the first padding layer's constant to match the biased input.
        model.model[0][0].value = input_bias
        val_loss, val_prec1, val_prec5 = validate(val_loader,
                                                  model,
                                                  criterion,
                                                  0,
                                                  None,
                                                  input_bias=input_bias,
                                                  eps_in=2. / 255,
                                                  mode='qd',
                                                  shorten=10)
        print("[NEMO] QD model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))
        model.id_stage()
        # In the integer stage the pad constant is expressed in integer
        # units, hence the 255/2 rescaling (the inverse of eps_in).
        model.model[0][0].value = input_bias * (255. / 2)
        val_loss, val_prec1, val_prec5 = validate(val_loader,
                                                  model,
                                                  criterion,
                                                  0,
                                                  None,
                                                  input_bias=input_bias,
                                                  eps_in=2. / 255,
                                                  mode='id',
                                                  shorten=10)
        print("[NEMO] ID model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))
        # NOTE(review): `model` is passed twice to export_onnx -- verify
        # against the nemo.utils.export_onnx signature.
        nemo.utils.export_onnx('mobilenet_%s_%d.onnx' %
                               (mobilenet_width_s, mobilenet_input),
                               model,
                               model, (3, mobilenet_input, mobilenet_input),
                               perm=None)
        import sys
        sys.exit(0)

    # --terminal: deep equivalence check between the fake-quantized (FQ),
    # quantized-deployment (QD) and integer-deployment (ID) stages.  Captures
    # intermediate activations at each stage on a single batch, reports
    # per-layer differences in units of the layer quantum (eps), dumps golden
    # reference vectors plus an integer ONNX model, then exits.
    if args.terminal:
        # Snapshot the FQ-stage weights before any stage transition.
        fqs = copy.deepcopy(model.state_dict())
        model.freeze_bn(reset_stats=True, disable_grad=True)
        # Capture FQ-stage per-layer input/output activations (one batch).
        bin_fq, bout_fq, _ = nemo.utils.get_intermediate_activations(
            model, validate, val_loader, model, criterion, 0, None, shorten=1)

        torch.save({'in': bin_fq['model.0.0'][0]}, "input_fq.pth")

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  0, None)
        print("[NEMO] FQ model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))

        # Input re-biasing into the unsigned range, rounded down to a
        # multiple of the input quantum eps_in = 2/255 (same scheme as the
        # --pure_export path).
        input_bias_dict = {'model.0.0': +1.0, 'model.0.1': +1.0}
        remove_bias_dict = {'model.0.1': 'model.0.2'}
        input_bias = math.floor(1.0 / (2. / 255)) * (2. / 255)

        model.qd_stage(eps_in=2. / 255,
                       add_input_bias_dict=input_bias_dict,
                       remove_bias_dict=remove_bias_dict,
                       int_accurate=True)

        # fix ConstantPad2d: its pad value must match the biased input
        model.model[0][0].value = input_bias

        val_loss, val_prec1, val_prec5 = validate(val_loader,
                                                  model,
                                                  criterion,
                                                  0,
                                                  None,
                                                  input_bias=input_bias,
                                                  eps_in=2. / 255,
                                                  mode='qd',
                                                  shorten=50)
        print("[NEMO] QD model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))

        qds = copy.deepcopy(model.state_dict())
        # Capture QD-stage activations on the same single batch.
        bin_qd, bout_qd, _ = nemo.utils.get_intermediate_activations(
            model,
            validate,
            val_loader,
            model,
            criterion,
            0,
            None,
            input_bias=input_bias,
            eps_in=2. / 255,
            mode='qd',
            shorten=1)

        torch.save({'qds': qds, 'fqs': fqs}, "states.pth")
        torch.save({'in': bin_qd['model.0.0'][0]}, "input_qd.pth")

        # FQ vs QD: absolute per-element output difference for every layer.
        diff = collections.OrderedDict()
        for k in bout_fq.keys():
            diff[k] = (bout_fq[k] - bout_qd[k]).to('cpu').abs()

        # Report, for sub-layer 3 of each of the 26 blocks, how many output
        # elements differ by more than one quantum of the consuming layer.
        for i in range(0, 26):
            for j in range(3, 4):
                k = 'model.%d.%d' % (i, j)
                # kn: the layer consuming k's output -- next index within the
                # block, or the first layer of the next block.
                kn = 'model.%d.%d' % (i if j < 3 else i + 1,
                                      j + 1 if j < 3 else 0)
                eps = model.get_eps_at(kn, eps_in=2. / 255)[0]
                print("%s:" % k)
                idx = diff[k] > eps
                n = idx.sum()
                # t = total element count (every element satisfies > -1e9).
                t = (diff[k] > -1e9).sum()
                max_eps = torch.ceil(
                    diff[k].max() /
                    model.get_eps_at('model.%d.0' %
                                     (i + 1), 2. / 255)[0]).item()
                mean_eps = torch.ceil(
                    diff[k][idx].mean() /
                    model.get_eps_at('model.%d.0' %
                                     (i + 1), 2. / 255)[0]).item()
                try:
                    print("  max:   %.3f (%d eps)" %
                          (diff[k].max().item(), max_eps))
                    print("  mean:  %.3f (%d eps) (only diff. elements)" %
                          (diff[k][idx].mean().item(), mean_eps))
                    print("  #diff: %d/%d (%.1f%%)" %
                          (n, t, float(n) / float(t) * 100))
                except ValueError:
                    # With no differing elements the mean of the empty
                    # selection is NaN, which %d-formatting rejects.
                    print("  #diff: 0/%d (0%%)" % (t, ))

        model.id_stage()
        # fix ConstantPad2d: in the integer stage the pad value is expressed
        # in integer units, hence the 255/2 rescaling (inverse of eps_in)
        model.model[0][0].value = input_bias * (255. / 2)

        # NOTE(review): `ids` is captured but never saved or read below.
        ids = model.state_dict()
        # Capture ID-stage activations on the same single batch.
        bin_id, bout_id, _ = nemo.utils.get_intermediate_activations(
            model,
            validate,
            val_loader,
            model,
            criterion,
            0,
            None,
            input_bias=input_bias,
            eps_in=2. / 255,
            mode='id',
            shorten=1)

        val_loss, val_prec1, val_prec5 = validate(val_loader,
                                                  model,
                                                  criterion,
                                                  0,
                                                  None,
                                                  input_bias=input_bias,
                                                  eps_in=2. / 255,
                                                  mode='id',
                                                  shorten=50)
        print("[NEMO] ID model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))

        # Golden-vector output directory (ignore "already exists").
        try:
            os.makedirs("golden")
        except Exception:
            pass

        # NOTE(review): this saves the FQ-stage input under the "id"
        # filename -- bin_id['model.0.0'][0] was probably intended.
        torch.save({'in': bin_fq['model.0.0'][0]}, "input_id.pth")

        # QD vs ID: ID outputs are in integer units, so rescale by the layer
        # quantum eps before comparing against the QD activations.
        diff = collections.OrderedDict()
        for i in range(0, 26):
            for j in range(3, 4):
                k = 'model.%d.%d' % (i, j)
                kn = 'model.%d.%d' % (i if j < 3 else i + 1,
                                      j + 1 if j < 3 else 0)
                eps = model.get_eps_at(kn, eps_in=2. / 255)[0]
                diff[k] = (bout_id[k] * eps - bout_qd[k]).to('cpu').abs()
                print("%s:" % k)
                idx = diff[k] >= eps
                n = idx.sum()
                t = (diff[k] > -1e9).sum()
                max_eps = torch.ceil(diff[k].max() / eps).item()
                mean_eps = torch.ceil(diff[k][idx].mean() / eps).item()
                try:
                    print("  max:   %.3f (%d eps)" %
                          (diff[k].max().item(), max_eps))
                    print("  mean:  %.3f (%d eps) (only diff. elements)" %
                          (diff[k][idx].mean().item(), mean_eps))
                    print("  #diff: %d/%d (%.1f%%)" %
                          (n, t, float(n) / float(t) * 100))
                except ValueError:
                    # With no differing elements the mean of the empty
                    # selection is NaN, which %d-formatting rejects.
                    print("  #diff: 0/%d (0%%)" % (t, ))
        # NOTE(review): interactive debug shell left in -- execution blocks
        # here until the IPython session exits.
        import IPython
        IPython.embed()

        # Dump per-layer golden input/output vectors (permuted to
        # channels-last when the buffer is 3-D) for comparison against the
        # deployed implementation.  NOTE(review): the loop variable `n`
        # shadows the element-count `n` used above.
        bidx = 0
        for n, m in model.named_modules():
            try:
                actbuf = bin_id[n][0][bidx].permute((1, 2, 0))
            except RuntimeError:
                # Non-3D buffers cannot be permuted to channels-last.
                actbuf = bin_id[n][0][bidx]
            np.savetxt("golden/golden_input_%s.txt" % n,
                       actbuf.cpu().detach().numpy().flatten(),
                       header="input (shape %s)" % (list(actbuf.shape)),
                       fmt="%.3f",
                       delimiter=',',
                       newline=',\n')
        for n, m in model.named_modules():
            try:
                actbuf = bout_id[n][bidx].permute((1, 2, 0))
            except RuntimeError:
                actbuf = bout_id[n][bidx]
            np.savetxt("golden/golden_%s.txt" % n,
                       actbuf.cpu().detach().numpy().flatten(),
                       header="%s (shape %s)" % (n, list(actbuf.shape)),
                       fmt="%.3f",
                       delimiter=',',
                       newline=',\n')
        # NOTE(review): `model` is passed twice to export_onnx -- verify
        # against the nemo.utils.export_onnx signature.
        nemo.utils.export_onnx("model_int.onnx",
                               model,
                               model, (3, 224, 224),
                               perm=None)

        # Full (non-shortened) ID-stage validation before exiting.
        val_loss, val_prec1, val_prec5 = validate(val_loader,
                                                  model,
                                                  criterion,
                                                  0,
                                                  None,
                                                  input_bias=input_bias,
                                                  eps_in=2. / 255)
        print("[NEMO] ID model: top-1=%.2f top-5=%.2f" %
              (val_prec1, val_prec5))

        # NOTE(review): second IPython debug shell left in before exit.
        import IPython
        IPython.embed()
        import sys
        sys.exit(0)

    # Quantization-aware fine-tuning loop: train one epoch, validate, save
    # checkpoints, and log/record metrics.
    for epoch in range(args.start_epoch, args.epochs):
        #        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch; BN is absorbed on the first epoch only and
        # kept frozen on every subsequent epoch.
        train_loss, train_prec1, train_prec5 = train(
            train_loader,
            model,
            criterion,
            epoch,
            optimizer,
            freeze_bn=True if epoch > 0 else False,
            absorb_bn=True if epoch == 0 else False)
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)

        # Per-epoch checkpoint (note: the epoch argument saved is always 0);
        # the "_mixed" suffix marks mixed-precision runs.
        if args.save_check:
            nemo.utils.save_checkpoint(
                model,
                optimizer,
                0,
                acc=val_prec1,
                checkpoint_name='mobilenet_%s_%d%s_checkpoint' %
                (mobilenet_width_s, mobilenet_input,
                 "_mixed" if mixed_prec_dict is not None else ""),
                checkpoint_suffix=args.suffix)

        # Separate checkpoint whenever top-1 accuracy improves.
        if is_best:
            nemo.utils.save_checkpoint(
                model,
                optimizer,
                0,
                acc=val_prec1,
                checkpoint_name='mobilenet_%s_%d%s_best' %
                (mobilenet_width_s, mobilenet_input,
                 "_mixed" if mixed_prec_dict is not None else ""),
                checkpoint_suffix=args.suffix)

        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \t'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        # Record error rates (100 - precision) in the CSV/HTML results log.
        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.save()