def _gen_model(self):
    """Create three MobileNetV1 nets and attach one UnstructuredPruner to each.

    Sets the following attributes, every pruner using ``mode='ratio'`` and
    ``ratio=0.55``:

    * ``self.net`` / ``self.pruner``: ``local_sparsity=True``.
    * ``self.net_conv1x1`` / ``self.pruner_conv1x1``:
      ``prune_params_type='conv1x1_only'`` and ``local_sparsity=False``.
    * ``self.net_mxn`` / ``self.pruner_mxn``: ``local_sparsity=True`` and
      ``sparse_block=[2, 1]``.
    """

    def build_net():
        # Every net is an untrained 10-class MobileNetV1.
        return mobilenet_v1(num_classes=10, pretrained=False)

    # Options common to all three pruners.
    shared = dict(mode='ratio', ratio=0.55)

    self.net = build_net()
    self.net_conv1x1 = build_net()
    self.net_mxn = build_net()

    self.pruner = UnstructuredPruner(self.net, local_sparsity=True, **shared)
    self.pruner_conv1x1 = UnstructuredPruner(
        self.net_conv1x1,
        prune_params_type='conv1x1_only',
        local_sparsity=False,
        **shared)
    self.pruner_mxn = UnstructuredPruner(
        self.net_mxn, local_sparsity=True, sparse_block=[2, 1], **shared)
Esempio n. 2
0
 def test_case5(self):
     """Save MobileNetV1 with save_cls_model and check the predicted latency is positive."""
     paddle.disable_static()
     net = mobilenet_v1()
     latency_predictor = TableLatencyPredictor(table_file='SD710')

     # save_cls_model hands back the model file and the parameter file.
     export_options = dict(input_shape=[1, 3, 224, 224],
                           save_dir="./inference_model",
                           data_type='fp32')
     model_file, param_file = save_cls_model(net, **export_options)

     latency = latency_predictor.predict(
         model_file=model_file, param_file=param_file, data_type='fp32')
     assert latency > 0
Esempio n. 3
0
 def test_case5(self):
     """Check that predict_latency returns a positive value for MobileNetV1 with task_type='seg'."""
     paddle.disable_static()
     net = mobilenet_v1()

     # Predictor settings for the '845' hardware entry.
     predictor_options = dict(hardware='845',
                              threads=4,
                              power_mode=3,
                              batchsize=1)
     latency_predictor = TableLatencyPredictor(f'./{opt_tool}',
                                               **predictor_options)

     request = dict(input_shape=[1, 3, 224, 224],
                    save_dir='./model',
                    data_type='fp32',
                    task_type='seg')
     latency = latency_predictor.predict_latency(net, **request)
     assert latency > 0
    def _gen_model(self):
        """Build a 10-class MobileNetV1 with a GMPUnstructuredPruner and check its ratio.

        The pruner is created with ``ratio=0.55`` and a schedule that resumes
        at iteration 500; the method then asserts that the pruner's current
        ``ratio`` attribute is greater than 0.3.
        """
        # Schedule passed to the pruner as ``configs``.
        gmp_configs = dict(
            stable_iterations=0,
            pruning_iterations=1000,
            tunning_iterations=1000,
            resume_iteration=500,
            pruning_steps=20,
            initial_ratio=0.05,
        )

        self.net = mobilenet_v1(num_classes=10, pretrained=False)
        self.pruner = GMPUnstructuredPruner(
            self.net, ratio=0.55, configs=gmp_configs)

        self.assertGreater(self.pruner.ratio, 0.3)
Esempio n. 5
0
def compress(args):
    """Train MobileNetV1 under unstructured pruning, evaluating and saving periodically.

    Args:
        args: Parsed command-line namespace. Fields read directly here:
            ``ce_test``, ``num_workers``, ``use_gpu``, ``data``,
            ``batch_size``, ``batch_size_for_validation``, ``checkpoint``,
            ``last_epoch``, ``pretrained_model``, ``pruning_strategy``,
            ``stable_epochs``, ``pruning_epochs``, ``tunning_epochs``,
            ``pruning_steps``, ``initial_ratio``, ``log_period``,
            ``num_epochs``, ``test_period``, ``model_period`` and
            ``model_path``. ``create_optimizer`` and
            ``create_unstructured_pruner`` also receive ``args`` and may read
            further fields. ``args.num_workers``, ``args.checkpoint`` and
            ``args.pretrained_model`` may be rewritten in place.

    Raises:
        ValueError: If ``args.data`` is neither ``"imagenet"`` nor
            ``"cifar10"``.
    """
    shuffle = True
    if args.ce_test:
        # Seed every RNG and turn off worker processes and shuffling.
        seed = 111
        paddle.seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        args.num_workers = 0
        shuffle = False

    if args.use_gpu:
        place = paddle.set_device('gpu')
    else:
        place = paddle.set_device('cpu')

    trainer_num = paddle.distributed.get_world_size()
    use_data_parallel = trainer_num != 1
    if use_data_parallel:
        dist.init_parallel_env()

    if args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
    elif args.data == "cifar10":
        normalize = T.Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5],
                                data_format='CHW')
        transform = T.Compose([T.Transpose(), normalize])
        train_dataset = paddle.vision.datasets.Cifar10(mode='train',
                                                       backend='cv2',
                                                       transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(mode='test',
                                                     backend='cv2',
                                                     transform=transform)
        class_dim = 10
    else:
        raise ValueError("{} is not supported.".format(args.data))

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        drop_last=True)

    train_loader = paddle.io.DataLoader(train_dataset,
                                        places=place,
                                        batch_sampler=batch_sampler,
                                        return_list=True,
                                        num_workers=args.num_workers,
                                        use_shared_memory=True)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        drop_last=False,
        return_list=True,
        batch_size=args.batch_size_for_validation,
        shuffle=False,
        use_shared_memory=True)
    step_per_epoch = int(
        np.ceil(len(train_dataset) / args.batch_size / ParallelEnv().nranks))
    # model definition
    model = mobilenet_v1(num_classes=class_dim, pretrained=True)
    if ParallelEnv().nranks > 1:
        model = paddle.DataParallel(model)

    opt, learning_rate = create_optimizer(args, step_per_epoch, model)

    if args.checkpoint is not None and args.last_epoch > -1:
        # Resume: accept either a bare prefix or a path ending in
        # '.pdparams' / '.pdopt', then restore model and optimizer state.
        if args.checkpoint.endswith('pdparams'):
            args.checkpoint = args.checkpoint[:-9]
        if args.checkpoint.endswith('pdopt'):
            args.checkpoint = args.checkpoint[:-6]
        model.set_state_dict(paddle.load(args.checkpoint + ".pdparams"))
        opt.set_state_dict(paddle.load(args.checkpoint + ".pdopt"))
    elif args.pretrained_model is not None:
        # Fresh run from user-supplied weights: only the parameters are loaded.
        if args.pretrained_model.endswith('pdparams'):
            args.pretrained_model = args.pretrained_model[:-9]
        if args.pretrained_model.endswith('pdopt'):
            args.pretrained_model = args.pretrained_model[:-6]
        model.set_state_dict(paddle.load(args.pretrained_model + ".pdparams"))

    if args.pruning_strategy == 'gmp':
        # GMP pruner step 0: define configs. No need to do this if you are not using 'gmp'
        configs = {
            'stable_iterations': args.stable_epochs * step_per_epoch,
            'pruning_iterations': args.pruning_epochs * step_per_epoch,
            'tunning_iterations': args.tunning_epochs * step_per_epoch,
            'resume_iteration': (args.last_epoch + 1) * step_per_epoch,
            'pruning_steps': args.pruning_steps,
            'initial_ratio': args.initial_ratio,
        }
    else:
        configs = None

    # GMP pruner step 1: initialize a pruner object
    pruner = create_unstructured_pruner(model, args, configs=configs)

    @paddle.no_grad()
    def test(epoch):
        """Evaluate on the validation loader and log top-1/top-5 accuracy."""
        model.eval()
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
            end_time = time.time()

            top1 = np.mean(acc_top1.numpy())
            top5 = np.mean(acc_top5.numpy())
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}"
                    .format(epoch, batch_id, top1, top5,
                            end_time - start_time))
            acc_top1_ns.append(top1)
            acc_top5_ns.append(top5)

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
                epoch, np.mean(np.array(acc_top1_ns, dtype="object")),
                np.mean(np.array(acc_top5_ns, dtype="object"))))

    def train(epoch):
        """Run one training epoch, stepping optimizer, LR scheduler and pruner per batch."""
        model.train()
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()

        for batch_id, data in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            train_start = time.time()
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)

            loss.backward()
            opt.step()
            learning_rate.step()
            opt.clear_grad()
            # GMP pruner step 2: step() to update ratios and other internal states of the pruner.
            pruner.step()

            train_run_cost += time.time() - train_start
            total_samples += args.batch_size

            if batch_id % args.log_period == 0:
                # The very first log line covers one batch only; every later
                # one covers a full log_period window.
                log_period = 1 if batch_id == 0 else args.log_period
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec"
                    .format(
                        epoch, batch_id, opt.get_lr(), np.mean(loss.numpy()),
                        np.mean(acc_top1.numpy()), np.mean(acc_top5.numpy()),
                        train_reader_cost / log_period,
                        (train_reader_cost + train_run_cost) / log_period,
                        total_samples / log_period,
                        total_samples / (train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            reader_start = time.time()

    for i in range(args.last_epoch + 1, args.num_epochs):
        train(i)
        # GMP pruner step 3: update params before summarizing sparsity, saving model or evaluation.
        pruner.update_params()

        if (i + 1) % args.test_period == 0:
            _logger.info(
                "The current sparsity of the pruned model is: {}%".format(
                    round(100 * UnstructuredPruner.total_sparse(model), 2)))
            test(i)

        if (i + 1) % args.model_period == 0:
            # Parameters were already updated right after train(i); nothing
            # has modified them since, so they can be saved as-is.
            paddle.save(model.state_dict(),
                        os.path.join(args.model_path, "model.pdparams"))
            paddle.save(opt.state_dict(),
                        os.path.join(args.model_path, "model.pdopt"))
Esempio n. 6
0
 def setUp(self):
     """Build a default MobileNetV1 and index its named parameters by name."""
     self.model = mobilenet_v1()
     # Maps parameter name -> the parameter object itself (no copy is made).
     self.origin_weights = {
         name: param
         for name, param in self.model.named_parameters()
     }
Esempio n. 7
0
def compress(args):
    """Train MobileNetV1 with an UnstructuredPruner, evaluating and saving periodically.

    Args:
        args: Parsed command-line namespace. Fields read directly here:
            ``use_gpu``, ``data``, ``batch_size``, ``num_workers``,
            ``batch_size_for_validation``, ``pretrained_model``,
            ``log_period``, ``pruning_mode``, ``ratio``, ``threshold``,
            ``resume_epoch``, ``num_epochs``, ``test_period``,
            ``model_period`` and ``model_path``. ``create_optimizer`` also
            receives ``args`` and may read further fields.

    Raises:
        ValueError: If ``args.data`` is neither ``"imagenet"`` nor
            ``"cifar10"``.
    """
    if args.use_gpu:
        place = paddle.set_device('gpu')
    else:
        place = paddle.set_device('cpu')

    trainer_num = paddle.distributed.get_world_size()
    use_data_parallel = trainer_num != 1
    if use_data_parallel:
        dist.init_parallel_env()

    if args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
    elif args.data == "cifar10":
        normalize = T.Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5],
                                data_format='CHW')
        transform = T.Compose([T.Transpose(), normalize])
        train_dataset = paddle.vision.datasets.Cifar10(mode='train',
                                                       backend='cv2',
                                                       transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(mode='test',
                                                     backend='cv2',
                                                     transform=transform)
        class_dim = 10
    else:
        raise ValueError("{} is not supported.".format(args.data))

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True)

    train_loader = paddle.io.DataLoader(train_dataset,
                                        places=place,
                                        batch_sampler=batch_sampler,
                                        return_list=True,
                                        num_workers=args.num_workers,
                                        use_shared_memory=True)

    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        drop_last=False,
        return_list=True,
        batch_size=args.batch_size_for_validation,
        shuffle=False,
        use_shared_memory=True)
    step_per_epoch = int(
        np.ceil(len(train_dataset) / args.batch_size / ParallelEnv().nranks))
    # model definition
    model = mobilenet_v1(num_classes=class_dim, pretrained=True)
    if ParallelEnv().nranks > 1:
        model = paddle.DataParallel(model)

    if args.pretrained_model is not None:
        model.set_state_dict(paddle.load(args.pretrained_model))

    opt, learning_rate = create_optimizer(args, step_per_epoch, model)

    @paddle.no_grad()
    def test(epoch):
        """Evaluate on the validation loader and log top-1/top-5 accuracy."""
        model.eval()
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
            end_time = time.time()

            top1 = np.mean(acc_top1.numpy())
            top5 = np.mean(acc_top5.numpy())
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}"
                    .format(epoch, batch_id, top1, top5,
                            end_time - start_time))
            acc_top1_ns.append(top1)
            acc_top5_ns.append(top5)

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
                epoch, np.mean(np.array(acc_top1_ns, dtype="object")),
                np.mean(np.array(acc_top5_ns, dtype="object"))))

    def train(epoch):
        """Run one training epoch, stepping optimizer, LR scheduler and pruner per batch."""
        model.train()
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()

        for batch_id, data in enumerate(train_loader):
            train_reader_cost += time.time() - reader_start
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            train_start = time.time()
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)

            loss.backward()
            opt.step()
            learning_rate.step()
            opt.clear_grad()
            # `pruner` is created below, before the first call to train().
            pruner.step()
            train_run_cost += time.time() - train_start
            total_samples += args.batch_size * ParallelEnv().nranks

            if batch_id % args.log_period == 0:
                # The very first log line covers one batch only; every later
                # one covers a full log_period window.
                log_period = 1 if batch_id == 0 else args.log_period
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {}; acc_top1: {}; acc_top5: {}; avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec"
                    .format(
                        epoch, batch_id, opt.get_lr(), np.mean(loss.numpy()),
                        np.mean(acc_top1.numpy()), np.mean(acc_top5.numpy()),
                        train_reader_cost / log_period,
                        (train_reader_cost + train_run_cost) / log_period,
                        total_samples / log_period,
                        total_samples / (train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            reader_start = time.time()

    pruner = UnstructuredPruner(model,
                                mode=args.pruning_mode,
                                ratio=args.ratio,
                                threshold=args.threshold)

    for i in range(args.resume_epoch + 1, args.num_epochs):
        train(i)
        if (i + 1) % args.test_period == 0:
            # Update the parameters before measuring and evaluating.
            pruner.update_params()
            # NOTE(review): the message says "density" while the helper is
            # named total_sparse — confirm which quantity it returns.
            _logger.info(
                "The current density of the pruned model is: {}%".format(
                    round(100 * UnstructuredPruner.total_sparse(model), 2)))
            test(i)
        if (i + 1) % args.model_period == 0:
            # Update the parameters before writing the checkpoint.
            pruner.update_params()
            paddle.save(model.state_dict(),
                        os.path.join(args.model_path, "model-pruned.pdparams"))
            paddle.save(opt.state_dict(),
                        os.path.join(args.model_path, "opt-pruned.pdopt"))
Esempio n. 8
0
def compress(args):
    """Load pruned MobileNetV1 weights, log the model's sparsity and evaluate it once.

    Args:
        args: Parsed command-line namespace. Fields read here: ``data``,
            ``use_gpu``, ``batch_size``, ``log_period`` and ``pruned_model``
            (path passed to ``paddle.load``).

    Raises:
        ValueError: If ``args.data`` is neither ``"imagenet"`` nor
            ``"cifar10"``.
    """
    if args.data == "imagenet":
        import imagenet_reader as reader
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
    elif args.data == "cifar10":
        normalize = T.Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5],
                                data_format='CHW')
        transform = T.Compose([T.Transpose(), normalize])
        val_dataset = paddle.vision.datasets.Cifar10(mode='test',
                                                     backend='cv2',
                                                     transform=transform)
        class_dim = 10
    else:
        raise ValueError("{} is not supported.".format(args.data))

    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    valid_loader = paddle.io.DataLoader(val_dataset,
                                        places=places,
                                        drop_last=False,
                                        return_list=True,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        use_shared_memory=True)

    # model definition
    model = mobilenet_v1(num_classes=class_dim, pretrained=True)

    @paddle.no_grad()
    def test(epoch):
        """Evaluate on the validation loader and log top-1/top-5 accuracy."""
        model.eval()
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)
            end_time = time.time()

            top1 = np.mean(acc_top1.numpy())
            top5 = np.mean(acc_top5.numpy())
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}"
                    .format(epoch, batch_id, top1, top5,
                            end_time - start_time))
            acc_top1_ns.append(top1)
            acc_top5_ns.append(top5)

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
                epoch, np.mean(np.array(acc_top1_ns, dtype="object")),
                np.mean(np.array(acc_top5_ns, dtype="object"))))

    # Replace the pretrained weights with the pruned ones before measuring.
    model.set_state_dict(paddle.load(args.pruned_model))
    _logger.info("The current sparsity of the pruned model is: {}%".format(
        round(100 * UnstructuredPruner.total_sparse(model), 2)))
    test(0)
Esempio n. 9
0
 def setUp(self):
     """Store a default-constructed MobileNetV1 on ``self.model``."""
     network = mobilenet_v1()
     self.model = network
Esempio n. 10
0
 def _gen_model(self):
     """Create a 10-class MobileNetV1 and an UnstructuredPruner for it.

     The pruner is built with ``mode='ratio'``, ``ratio=0.98`` and
     ``threshold=0.0``; the results are stored on ``self.net`` and
     ``self.pruner``.
     """
     pruner_options = dict(mode='ratio', ratio=0.98, threshold=0.0)
     self.net = mobilenet_v1(num_classes=10, pretrained=False)
     self.pruner = UnstructuredPruner(self.net, **pruner_options)
Esempio n. 11
0
    latency = predictor.predict_latency(
        model,
        input_shape=[1, 3, 224, 224],
        save_dir='./tmp_model',
        data_type=data_type,
        task_type='cls')
    print('{} latency : {}'.format(data_type, latency))

    subprocess.call('rm -rf ./tmp_model', shell=True)
    paddle.disable_static()
    return latency


if __name__ == '__main__':
    # Build the network requested on the command line. An unknown name is an
    # input error, so raise instead of asserting (asserts vanish under -O).
    if args.model == 'mobilenet_v1':
        model = mobilenet_v1()
    elif args.model == 'mobilenet_v2':
        model = mobilenet_v2()
    else:
        raise ValueError('model should be mobilenet_v1 or mobilenet_v2')

    latency = get_latency(model, args.data_type)

    # Reference latencies keyed by (model, data_type). Combinations that are
    # not listed are not checked.
    expected_latency = {
        ('mobilenet_v1', 'fp32'): 41.92806607483133,
        ('mobilenet_v1', 'int8'): 36.64814722993898,
        ('mobilenet_v2', 'fp32'): 27.847896889217566,
        ('mobilenet_v2', 'int8'): 23.967800360138803,
    }
    reference = expected_latency.get((args.model, args.data_type))
    if reference is not None:
        # Exact comparison is kept as in the original regression check.
        assert latency == reference
Esempio n. 12
0
def compress(args):
    """Run quantization-aware training (QAT) and export the best quantized model.

    Steps: build the dataset and network, quantize the network with ``QAT``,
    train for ``args.num_epochs`` epochs while checkpointing every epoch and
    the best top-1 model, then (on rank 0) reload the best checkpoint and
    save it with ``quanter.save_quantized_model``.

    Args:
        args: Parsed command-line namespace. Fields read directly here:
            ``data``, ``use_gpu``, ``model``, ``pretrained_model``,
            ``use_pact``, ``batch_size``, ``log_period``, ``ls_epsilon``,
            ``num_epochs`` and ``model_save_dir``. ``args.total_images`` is
            set to 50000 for cifar10. ``create_optimizer`` also receives
            ``args`` and may read further fields.

    Raises:
        ValueError: If ``args.data`` or ``args.model`` is not supported.
    """
    if args.data == "cifar10":
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.Cifar10(mode="train",
                                                       backend="cv2",
                                                       transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(mode="test",
                                                     backend="cv2",
                                                     transform=transform)
        class_dim = 10
        args.total_images = 50000
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_dataset = reader.ImageNetDataset(mode='train')
        val_dataset = reader.ImageNetDataset(mode='val')
        class_dim = 1000
    else:
        raise ValueError("{} is not supported.".format(args.data))

    trainer_num = paddle.distributed.get_world_size()
    use_data_parallel = trainer_num != 1

    place = paddle.set_device('gpu' if args.use_gpu else 'cpu')
    # model definition
    if use_data_parallel:
        paddle.distributed.init_parallel_env()

    # Pretrained weights are only used for imagenet.
    pretrain = args.data == "imagenet"
    if args.model == "mobilenet_v1":
        net = mobilenet_v1(pretrained=pretrain, num_classes=class_dim)
    elif args.model == "mobilenet_v3":
        net = MobileNetV3_large_x1_0(class_dim=class_dim)
        if pretrain:
            load_dygraph_pretrain(net, args.pretrained_model, True)
    else:
        raise ValueError("{} is not supported.".format(args.model))
    _logger.info("Origin model summary:")
    # NOTE(review): the summary (and the exported InputSpec at the end) use a
    # fixed 3x224x224 input even when args.data is cifar10 — confirm intended.
    paddle.summary(net, (1, 3, 224, 224))

    ############################################################################################################
    # 1. quantization configs
    ############################################################################################################
    quant_config = {
        # weight preprocess type, default is None and no preprocessing is performed.
        'weight_preprocess_type': None,
        # activation preprocess type, default is None and no preprocessing is performed.
        'activation_preprocess_type': None,
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
        # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
        'quantizable_layer_type': ['Conv2D', 'Linear'],
    }

    if args.use_pact:
        quant_config['activation_preprocess_type'] = 'PACT'

    ############################################################################################################
    # 2. Quantize the model with QAT (quant aware training)
    ############################################################################################################

    quanter = QAT(config=quant_config)
    quanter.quantize(net)

    _logger.info("QAT model summary:")
    paddle.summary(net, (1, 3, 224, 224))

    # The optimizer is created after quantization and before the
    # DataParallel wrap; keep this order.
    opt, lr = create_optimizer(net, trainer_num, args)

    if use_data_parallel:
        net = paddle.DataParallel(net)

    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True)
    train_loader = paddle.io.DataLoader(train_dataset,
                                        batch_sampler=train_batch_sampler,
                                        places=place,
                                        return_list=True,
                                        num_workers=4)

    valid_loader = paddle.io.DataLoader(val_dataset,
                                        places=place,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        drop_last=False,
                                        return_list=True,
                                        num_workers=4)

    @paddle.no_grad()
    def test(epoch, net):
        """Evaluate ``net`` on the validation loader and return the mean top-1 accuracy."""
        net.eval()
        batch_id = 0
        acc_top1_ns = []
        acc_top5_ns = []

        eval_reader_cost = 0.0
        eval_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()
        for data in valid_loader():
            eval_reader_cost += time.time() - reader_start
            image = data[0]
            label = data[1]
            if args.data == "cifar10":
                label = paddle.reshape(label, [-1, 1])

            eval_start = time.time()

            out = net(image)
            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)

            eval_run_cost += time.time() - eval_start
            batch_size = image.shape[0]
            total_samples += batch_size

            if batch_id % args.log_period == 0:
                # The first log line covers a single batch only.
                log_period = 1 if batch_id == 0 else args.log_period
                _logger.info(
                    "Eval epoch[{}] batch[{}] - top1: {:.6f}; top5: {:.6f}; avg_reader_cost: {:.6f} s, avg_batch_cost: {:.6f} s, avg_samples: {}, avg_ips: {:.3f} images/s"
                    .format(epoch, batch_id, np.mean(acc_top1.numpy()),
                            np.mean(acc_top5.numpy()),
                            eval_reader_cost / log_period,
                            (eval_reader_cost + eval_run_cost) / log_period,
                            total_samples / log_period, total_samples /
                            (eval_reader_cost + eval_run_cost)))
                eval_reader_cost = 0.0
                eval_run_cost = 0.0
                total_samples = 0
            acc_top1_ns.append(np.mean(acc_top1.numpy()))
            acc_top5_ns.append(np.mean(acc_top5.numpy()))
            batch_id += 1
            reader_start = time.time()

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
                epoch, np.mean(np.array(acc_top1_ns)),
                np.mean(np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def cross_entropy(input, target, ls_epsilon):
        """Return the mean cross-entropy, label-smoothed when ``ls_epsilon > 0``."""
        if ls_epsilon > 0:
            if target.shape[-1] != class_dim:
                target = paddle.nn.functional.one_hot(target, class_dim)
            target = paddle.nn.functional.label_smooth(target,
                                                       epsilon=ls_epsilon)
            target = paddle.reshape(target, shape=[-1, class_dim])
            input = -paddle.nn.functional.log_softmax(input, axis=-1)
            cost = paddle.sum(target * input, axis=-1)
        else:
            cost = paddle.nn.functional.cross_entropy(input=input,
                                                      label=target)
        avg_cost = paddle.mean(cost)
        return avg_cost

    def train(epoch, net):
        """Run one training epoch of ``net``, stepping the optimizer and LR scheduler per batch."""

        net.train()
        batch_id = 0

        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()
        for data in train_loader():
            train_reader_cost += time.time() - reader_start

            image = data[0]
            label = data[1]
            if args.data == "cifar10":
                label = paddle.reshape(label, [-1, 1])

            train_start = time.time()
            out = net(image)
            avg_cost = cross_entropy(out, label, args.ls_epsilon)

            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
            avg_cost.backward()
            opt.step()
            opt.clear_grad()
            lr.step()

            loss_n = np.mean(avg_cost.numpy())
            acc_top1_n = np.mean(acc_top1.numpy())
            acc_top5_n = np.mean(acc_top5.numpy())

            train_run_cost += time.time() - train_start
            batch_size = image.shape[0]
            total_samples += batch_size

            if batch_id % args.log_period == 0:
                # The first log line covers a single batch only.
                log_period = 1 if batch_id == 0 else args.log_period
                _logger.info(
                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; top1: {:.6f}; top5: {:.6f}; avg_reader_cost: {:.6f} s, avg_batch_cost: {:.6f} s, avg_samples: {}, avg_ips: {:.3f} images/s"
                    .format(
                        epoch, batch_id, lr.get_lr(), loss_n, acc_top1_n,
                        acc_top5_n, train_reader_cost / log_period,
                        (train_reader_cost + train_run_cost) / log_period,
                        total_samples / log_period,
                        total_samples / (train_reader_cost + train_run_cost)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            batch_id += 1
            reader_start = time.time()

    ############################################################################################################
    # train loop
    ############################################################################################################
    best_acc1 = 0.0
    for i in range(args.num_epochs):
        train(i, net)
        acc1 = test(i, net)
        # Only rank 0 writes checkpoints.
        if paddle.distributed.get_rank() == 0:
            model_prefix = os.path.join(args.model_save_dir, "epoch_" + str(i))
            paddle.save(net.state_dict(), model_prefix + ".pdparams")
            paddle.save(opt.state_dict(), model_prefix + ".pdopt")

        if acc1 > best_acc1:
            best_acc1 = acc1
            if paddle.distributed.get_rank() == 0:
                model_prefix = os.path.join(args.model_save_dir, "best_model")
                paddle.save(net.state_dict(), model_prefix + ".pdparams")
                paddle.save(opt.state_dict(), model_prefix + ".pdopt")

    ############################################################################################################
    # 3. Save quant aware model
    ############################################################################################################
    if paddle.distributed.get_rank() == 0:
        # load best model
        # NOTE(review): "best_model" is only written when some epoch's top-1
        # exceeds 0.0 — confirm this load cannot run without it.
        load_dygraph_pretrain(net,
                              os.path.join(args.model_save_dir, "best_model"))

        path = os.path.join(args.model_save_dir, "inference_model",
                            'qat_model')
        quanter.save_quantized_model(net,
                                     path,
                                     input_spec=[
                                         paddle.static.InputSpec(
                                             shape=[None, 3, 224, 224],
                                             dtype='float32')
                                     ])