Example #1
def init_model(config, net, optimizer=None):
    """
    load model from checkpoint or pretrained_model
    """
    checkpoints = config.get('checkpoints')
    if checkpoints and optimizer is not None:
        assert os.path.exists(checkpoints + ".pdparams"), \
            "Given path {}.pdparams does not exist.".format(checkpoints)
        assert os.path.exists(checkpoints + ".pdopt"), \
            "Given path {}.pdopt does not exist.".format(checkpoints)
        para_dict = paddle.load(checkpoints + ".pdparams")
        opti_dict = paddle.load(checkpoints + ".pdopt")
        net.set_dict(para_dict)
        optimizer.set_state_dict(opti_dict)
        logger.info("Finish load checkpoints from {}".format(checkpoints))
        return

    pretrained_model = config.get('pretrained_model')
    load_static_weights = config.get('load_static_weights', False)
    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
        if use_distillation:
            load_distillation_model(net, pretrained_model, load_static_weights)
        else:  # common load
            load_dygraph_pretrain(net,
                                  path=pretrained_model,
                                  load_static_weights=load_static_weights)
            logger.info(
                logger.coloring(
                    "Finished loading pretrained model from {}".format(
                        pretrained_model), "HEADER"))
Example #2
def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader): data loader that yields feed dicts
        exe(fluid.Executor): executor used to run the program
        program(fluid.Program): the program to run
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): running mode, used only for logging

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('cost', ':6.3f')
    tic = time.time()
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
    for idx, batch in enumerate(dataloader()):
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        tic = time.time()
        for i, m in enumerate(metrics):
            metric_list[i].update(m[0], len(batch[0]))
        fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)])
        if trainer_id == 0:
            logger.info("[epoch:%3d][%s][step:%4d]%s" %
                        (epoch, mode, idx, fetchs_str))
    if trainer_id == 0:
        logger.info("END [epoch:%3d][%s]%s" % (epoch, mode, fetchs_str))
Example #3
def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader): data loader that yields feed dicts
        exe(fluid.Executor): executor used to run the program
        program(fluid.Program): the program to run
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): running mode, used only for logging

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('cost', '.3f')
    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        tic = time.time()
        for i, m in enumerate(metrics):
            metric_list[i].update(m[0], len(batch[0]))
        fetchs_str = ''.join([m.value
                              for m in metric_list] + [batch_time.value])
        logger.info("[epoch:{:3d}][{:s}][step:{:4d}]{:s}".format(
            epoch, mode, idx, fetchs_str))
    end_str = ''.join([m.mean for m in metric_list] + [batch_time.total])
    logger.info("END [epoch:{:3d}][{:s}]{:s}".format(epoch, mode, end_str))
Example #4
def log_info(trainer, batch_size, epoch_id, iter_id):
    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, trainer.output_info[key].avg)
        for key in trainer.output_info
    ])
    time_msg = "s, ".join([
        "{}: {:.5f}".format(key, trainer.time_info[key].avg)
        for key in trainer.time_info
    ])

    ips_msg = "ips: {:.5f} images/sec".format(
        batch_size / trainer.time_info["batch_cost"].avg)
    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
                ) * len(trainer.train_dataloader) - iter_id
               ) * trainer.time_info["batch_cost"].avg
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
    logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
        epoch_id, trainer.config["Global"]["epochs"], iter_id,
        len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
        eta_msg))

    logger.scaler(
        name="lr",
        value=trainer.lr_sch.get_lr(),
        step=trainer.global_step,
        writer=trainer.vdl_writer)
    for key in trainer.output_info:
        logger.scaler(
            name="train_{}".format(key),
            value=trainer.output_info[key].avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)
Example #5
def _decompress(fname):
    """
    Decompress for zip and tar file
    """
    logger.info("Decompressing {}...".format(fname))

    # To protect against interrupted decompression, decompress
    # into the fpath_tmp directory first; if decompression succeeds,
    # move the decompressed files to fpath, then delete fpath_tmp
    # and remove the downloaded archive.
    fpath = os.path.split(fname)[0]
    fpath_tmp = os.path.join(fpath, 'tmp')
    if os.path.isdir(fpath_tmp):
        shutil.rmtree(fpath_tmp)
    # create the tmp dir whether or not a stale one existed
    os.makedirs(fpath_tmp)

    if fname.find('tar') >= 0:
        with tarfile.open(fname) as tf:
            tf.extractall(path=fpath_tmp)
    elif fname.find('zip') >= 0:
        with zipfile.ZipFile(fname) as zf:
            zf.extractall(path=fpath_tmp)
    else:
        raise TypeError("Unsupport compress file type {}".format(fname))

    for f in os.listdir(fpath_tmp):
        src_dir = os.path.join(fpath_tmp, f)
        dst_dir = os.path.join(fpath, f)
        _move_and_merge_tree(src_dir, dst_dir)

    shutil.rmtree(fpath_tmp)
    os.remove(fname)
Example #6
def term_mp(sig_num, frame):
    """ kill all child processes
    """
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    logger.info("main proc {} exit, kill process group "
                "{}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)
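
A minimal sketch of how term_mp is meant to be wired up: register it as a handler for termination signals so the whole process group dies with the main process (the registration itself is not shown in the original).

import os
import signal

# Register term_mp for common termination signals; when the main process
# receives one, the handler kills the entire process group.
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)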
Example #7
def save_model(program, model_path, epoch_id, prefix='ppcls'):
    """
    save model to the target path
    """
    model_path = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)
    fluid.save(program, model_prefix)
    logger.info("Already save model in {}".format(model_path))
Example #8
def get(architecture, path, decompress=True):
    """
    Get the pretrained model.
    """
    _check_pretrained_name(architecture)
    url = _get_url(architecture)
    fname = _download(url, path)
    if decompress: _decompress(fname)
    logger.info("download {} finished ".format(fname))
Example #9
def get(architecture, path, decompress=False, postfix="pdparams"):
    """
    Get the pretrained model.
    """
    _check_pretrained_name(architecture)
    url = _get_url(architecture, postfix=postfix)
    fname = _download(url, path)
    if postfix == "tar" and decompress:
        _decompress(fname)
    logger.info("download {} finished ".format(fname))
Example #10
def _save_student_model(net, model_prefix):
    """
    save student model if the net is the network contains student
    """
    student_model_prefix = model_prefix + "_student.pdparams"
    if hasattr(net, "_layers"):
        net = net._layers
    if hasattr(net, "student"):
        paddle.save(net.student.state_dict(), student_model_prefix)
        logger.info(
            "Already saved student model in {}".format(student_model_prefix))
Example #11
def save_model(program, model_path, epoch_id, prefix='ppcls'):
    """
    save model to the target path
    """
    model_path = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)
    paddle.static.save(program, model_prefix)
    logger.info(
        logger.coloring("Already saved model in {}".format(model_path),
                        "HEADER"))
Example #12
def _download(url, path, md5sum=None):
    """
    Download from url, save to path.

    url (str): download url
    path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {} from {}".format(fname, url))

        try:
            req = requests.get(url, stream=True)
        except Exception as e:  # requests.exceptions.ConnectionError
            logger.info(
                "Downloading {} from {} failed {} times with exception {}".
                format(fname, url, retry_cnt + 1, str(e)))
            time.sleep(1)
            continue

        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, download to
        # tmp_fullname first and move tmp_fullname to fullname
        # after the download finishes
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname
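
A minimal usage sketch for _download; the URL is a placeholder and no checksum is verified when md5sum is None.

# Download into ./weights with retry; the URL below is a placeholder.
# Passing md5sum=None skips checksum verification (see _md5check).
weights_url = "https://example.com/models/ResNet50_pretrained.pdparams"
local_file = _download(weights_url, "./weights", md5sum=None)
print("saved to", local_file)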
Example #13
def save_model(program, model_path, epoch_id, prefix='ppcls'):
    """
    save model to the target path
    """
    if paddle.distributed.get_rank() != 0:
        return
    model_path = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)
    paddle.static.save(program, model_prefix)
    logger.info("Already save model in {}".format(model_path))
Example #14
def apply_to_static(config, model):
    support_to_static = config['Global'].get('to_static', False)

    if support_to_static:
        specs = None
        if 'image_shape' in config['Global']:
            specs = [InputSpec([None] + config['Global']['image_shape'])]
        model = to_static(model, input_spec=specs)
        logger.info("Successfully to apply @to_static with specs: {}".format(
            specs))
    return model
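
A minimal config sketch for apply_to_static, assuming InputSpec and to_static are imported in the original module from paddle.static and paddle.jit; the model and image_shape below are illustrative.

import paddle
from paddle.static import InputSpec   # assumed import in the original module
from paddle.jit import to_static      # assumed import in the original module

model = paddle.vision.models.resnet50()            # any paddle.nn.Layer
config = {"Global": {"to_static": True, "image_shape": [3, 224, 224]}}
model = apply_to_static(config, model)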
Example #15
def main(args):
    benchmark_file_list = args.benchmark_file_list
    model_infos = parse_model_infos(benchmark_file_list)
    right_models = []
    wrong_models = []

    for model_info in model_infos:
        try:
            pretrained_url = model_info["pretrain_path"]
            fname = _download(pretrained_url, args.pretrained_dir)
            pretrained_path = os.path.splitext(fname)[0]
            if pretrained_url.endswith("tar"):
                path = _decompress(fname)
                pretrained_path = os.path.join(
                    os.path.dirname(pretrained_path), path)

            args.config = model_info["config_path"]
            args.override = [
                "pretrained_model={}".format(pretrained_path),
                "VALID.batch_size=256",
                "VALID.num_workers=16",
                "load_static_weights=True",
                "print_interval=100",
            ]

            manager = Manager()
            return_dict = manager.dict()

            # A hack to avoid name conflicts.
            # Multi-processing may be a better approach here.
            # More details can be seen in branch 2.0-beta.
            # TODO: fluid needs to be removed in the future.
            with paddle.utils.unique_name.guard():
                eval.main(args, return_dict)

            top1_acc = return_dict.get("top1_acc", 0.0)
        except Exception as e:
            logger.error(e)
            top1_acc = 0.0
        diff = abs(top1_acc - model_info["top1_acc"])
        if diff > 0.001:
            err_info = "[{}]Top-1 acc diff should be <= 0.001 but got diff {}, gt acc: {}, eval acc: {}".format(
                model_info["model_name"], diff, model_info["top1_acc"],
                top1_acc)
            logger.warning(err_info)
            wrong_models.append(model_info["model_name"])
        else:
            right_models.append(model_info["model_name"])

    logger.info("[number of right models: {}, they are: {}".format(
        len(right_models), right_models))
    logger.info("[number of wrong models: {}, they are: {}".format(
        len(wrong_models), wrong_models))
Example #16
def _get_unique_endpoints(trainer_endpoints):
    # Sort so that every card derives the same set of unique endpoints,
    # regardless of the order of its environment variables
    trainer_endpoints.sort()
    ips = set()
    unique_endpoints = set()
    for endpoint in trainer_endpoints:
        ip = endpoint.split(":")[0]
        if ip in ips:
            continue
        ips.add(ip)
        unique_endpoints.add(endpoint)
    logger.info("unique_endpoints {}".format(unique_endpoints))
    return unique_endpoints
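
A small worked example with hypothetical endpoints: two trainers share an IP, so only the first endpoint per IP (after sorting) is kept.

# Hypothetical endpoints: two trainers share 127.0.0.1.
endpoints = ["127.0.0.1:6171", "127.0.0.1:6170", "10.0.0.2:6170"]
unique = _get_unique_endpoints(endpoints)
# After sorting, the first endpoint per IP is kept:
# unique == {"10.0.0.2:6170", "127.0.0.1:6170"}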
Example #17
def quantize_model(config, model):
    if config.get("Slim", False) and config["Slim"].get("quant", False):
        from paddleslim.dygraph.quant import QAT
        assert config["Slim"]["quant"]["name"].lower(
        ) == 'pact', 'Only PACT quantization method is supported now'
        QUANT_CONFIG["activation_preprocess_type"] = "PACT"
        model.quanter = QAT(config=QUANT_CONFIG)
        model.quanter.quantize(model)
        logger.info("QAT model summary:")
        paddle.summary(model, (1, 3, 224, 224))
    else:
        model.quanter = None
    return
Example #18
def init_model(config, program, exe):
    """
    load model from checkpoint or pretrained_model
    """
    checkpoints = config.get('checkpoints')
    if checkpoints:
        fluid.load(program, checkpoints, exe)
        logger.info("Finish initing model from {}".format(checkpoints))
        return

    pretrained_model = config.get('pretrained_model')
    if pretrained_model:
        load_params(exe, program, pretrained_model)
        logger.info("Finish initing model from {}".format(pretrained_model))
Example #19
def get_gpu_count():
    """get avaliable gpu count

    Returns:
        gpu_count: int
    """

    gpu_count = 0

    env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env_cuda_devices is not None:
        assert isinstance(env_cuda_devices, str)
        try:
            if not env_cuda_devices:
                return 0
            gpu_count = len(
                [x for x in env_cuda_devices.split(',') if int(x) >= 0])
            logger.info(
                'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
        except Exception:
            logger.info('Cannot find available GPU devices, using CPU now.')
            gpu_count = 0
    else:
        try:
            gpu_count = str(subprocess.check_output(["nvidia-smi",
                                                     "-L"])).count('UUID')
            logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
        except Exception:
            logger.info('Cannot find available GPU devices, using CPU now.')
            gpu_count = 0
    return gpu_count
Example #20
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader): data loader that yields feed dicts
        exe(fluid.Executor): executor used to run the program
        program(fluid.Program): the program to run
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): running mode, used only for logging
        vdl_writer: VisualDL writer for scalar logging (optional)

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.3f')
    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        tic = time.time()
        for i, m in enumerate(metrics):
            metric_list[i].update(m[0], len(batch[0]))
        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.value]) + 's'
        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'eval':
            logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            logger.info("{:s} {:s} {:s}".format(
                logger.coloring(epoch_str, "HEADER")
                if idx == 0 else epoch_str,
                logger.coloring(step_str, "PURPLE"),
                logger.coloring(fetchs_str, 'OKGREEN')))

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    if mode == 'eval':
        logger.info("END {:s} {:s}s".format(mode, end_str))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)

        logger.info("{:s} {:s} {:s}".format(
            logger.coloring(end_epoch_str, "RED"),
            logger.coloring(mode, "PURPLE"),
            logger.coloring(end_str, "OKGREEN")))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg
Example #21
def load_params(exe, prog, path, ignore_params=None):
    """
    Load model from the given path.
    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): load weight to which Program object.
        path (string): URL string or local model path.
        ignore_params (list): ignore variable to load when finetuning.
            It can be specified by finetune_exclude_pretrained_params
            and the usage can refer to the document
            docs/advanced_tutorials/TRANSFER_LEARNING.md
    """
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))

    logger.info(
        logger.coloring('Loading parameters from {}...'.format(path),
                        'HEADER'))

    ignore_set = set()
    state = _load_state(path)

    # ignore the parameter which mismatch the shape
    # between the model and pretrain weight.
    all_var_shape = {}
    for block in prog.blocks:
        for param in block.all_parameters():
            all_var_shape[param.name] = param.shape
    ignore_set.update([
        name for name, shape in all_var_shape.items()
        if name in state and shape != state[name].shape
    ])

    if ignore_params:
        all_var_names = [var.name for var in prog.list_vars()]
        ignore_list = filter(
            lambda var: any([re.match(name, var) for name in ignore_params]),
            all_var_names)
        ignore_set.update(list(ignore_list))

    if len(ignore_set) > 0:
        for k in ignore_set:
            if k in state:
                logger.warning(
                    'variable {} is already excluded automatically'.format(k))
                del state[k]

    paddle.static.set_program_state(prog, state)
Example #22
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
    """
    save model to the target path
    """
    if paddle.distributed.get_rank() != 0:
        return
    model_path = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)

    _save_student_model(net, model_prefix)

    paddle.save(net.state_dict(), model_prefix + ".pdparams")
    paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
    logger.info("Already save model in {}".format(model_path))
Example #23
def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    logger.info("File {} md5 checking...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logger.info("File {} md5 check failed, {}(calc) != "
                    "{}(base)".format(fullname, calc_md5sum, md5sum))
        return False
    return True
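
A minimal sketch of _md5check in use: compute a reference digest for a local file, then verify it (the file name is illustrative).

import hashlib

# Compute a reference md5 for an existing local file (name illustrative).
with open("weights.pdparams", "rb") as f:
    ref = hashlib.md5(f.read()).hexdigest()

assert _md5check("weights.pdparams", ref)   # digest matches -> True
assert _md5check("weights.pdparams", None)  # no checksum -> trivially True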
Example #24
def init_model(config, program, exe):
    """
    load model from checkpoint or pretrained_model
    """
    checkpoints = config.get('checkpoints')
    if checkpoints:
        paddle.static.load(program, checkpoints, exe)
        logger.info("Finish initing model from {}".format(checkpoints))
        return

    pretrained_model = config.get('pretrained_model')
    if pretrained_model:
        if not isinstance(pretrained_model, list):
            pretrained_model = [pretrained_model]
        for pretrain in pretrained_model:
            load_params(exe, program, pretrain)
        logger.info("Finish initing model from {}".format(pretrained_model))
Example #25
def parse_model_infos(benchmark_file_list):
    model_infos = []
    with open(benchmark_file_list, "r") as fin:
        lines = fin.readlines()
        for idx, line in enumerate(lines):
            strs = line.strip("\n").strip("\r").split(" ")
            if len(strs) != 4:
                logger.info(
                    "line {0} (info: {1}) has a wrong format; it should split into 4 parts, but got {2}"
                    .format(idx, line, len(strs)))
            model_infos.append({
                "top1_acc": float(strs[0]),
                "model_name": strs[1],
                "config_path": strs[2],
                "pretrain_path": strs[3],
            })
    return model_infos
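
A sketch of the expected input: each line carries four space-separated fields, top-1 accuracy, model name, config path, and pretrained-weights path (the file contents below are illustrative).

# Each line: <top1_acc> <model_name> <config_path> <pretrain_path>
with open("benchmark_list.txt", "w") as f:
    f.write("0.765 ResNet50 configs/ResNet50.yaml pretrained/ResNet50\n")

infos = parse_model_infos("benchmark_list.txt")
print(infos[0]["model_name"])   # -> "ResNet50"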
Example #26
def get_path_from_url(url,
                      root_dir,
                      md5sum=None,
                      check_exist=True,
                      decompress=True):
    """ Download from given url to root_dir.
    if file or directory specified by url is exists under
    root_dir, return the path directly, otherwise download
    from url and decompress it, return the path.

    Args:
        url (str): download url
        root_dir (str): root dir for downloading, it should be
                        WEIGHTS_HOME or DATASET_HOME
        md5sum (str): md5 sum of download package
    
    Returns:
        str: a local path to save downloaded models & weights & datasets.
    """

    from paddle.fluid.dygraph.parallel import ParallelEnv

    assert is_url(url), "{} is not a valid url".format(url)
    # parse path after download to decompress under root_dir
    fullpath = _map_path(url, root_dir)
    # Mainly used to solve the problem of downloading data from different
    # machines in the case of multiple machines. Different ips will download
    # data, and the same ip will only download data once.
    unique_endpoints = _get_unique_endpoints(
        ParallelEnv().trainer_endpoints[:])
    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
        logger.info("Found {}".format(fullpath))
    else:
        if ParallelEnv().current_endpoint in unique_endpoints:
            fullpath = _download(url, root_dir, md5sum)
        else:
            while not os.path.exists(fullpath):
                time.sleep(1)

    if ParallelEnv().current_endpoint in unique_endpoints:
        if decompress and (tarfile.is_tarfile(fullpath)
                           or zipfile.is_zipfile(fullpath)):
            fullpath = _decompress(fullpath)

    return fullpath
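
A minimal usage sketch for get_path_from_url; the URL is a placeholder and ./weights stands in for WEIGHTS_HOME. An existing verified file is reused, and tar/zip archives are decompressed.

# The URL is a placeholder; ./weights stands in for WEIGHTS_HOME.
url = "https://example.com/models/ResNet50_pretrained.tar"
local_path = get_path_from_url(url, "./weights", md5sum=None)
print("model available at", local_path)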
Example #27
 def __iter__(self):
     while self.iter_counter < self.length:
         batch = []
         for i, iter_i in enumerate(self.iter_list):
             batch_i = next(iter_i, None)
             if batch_i is None:
                 iter_i = iter(self.sampler_list[i])
                 self.iter_list[i] = iter_i
                 batch_i = next(iter_i, None)
                 assert batch_i is not None, "dataset {} return None".format(
                     i)
             batch += [idx + self.start_list[i] for idx in batch_i]
         if len(batch) == self.batch_size:
             self.iter_counter += 1
             yield batch
         else:
             logger.info("Some dataset reaches end")
     self.iter_counter = 0
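
A standalone sketch of the same batch-mixing idea with plain lists (no sampler classes): each source yields fixed-size index batches, an exhausted source restarts, and a per-source offset keeps the merged indices distinct.

# Toy "batch samplers": lists of index batches; offsets keep them distinct.
sampler_list = [[[0, 1], [2, 3]], [[0], [1], [2]]]
start_list = [0, 100]
iter_list = [iter(s) for s in sampler_list]

for _ in range(3):                      # emit 3 mixed batches
    batch = []
    for i in range(len(iter_list)):
        batch_i = next(iter_list[i], None)
        if batch_i is None:             # restart an exhausted source
            iter_list[i] = iter(sampler_list[i])
            batch_i = next(iter_list[i])
        batch += [idx + start_list[i] for idx in batch_i]
    print(batch)
# -> [0, 1, 100] / [2, 3, 101] / [0, 1, 102]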
Example #28
def _download(url, path):
    """
    Download from url, save to path.
    url (str): download url
    path (str): download to given path
    """
    if not os.path.exists(path):
        os.makedirs(path)

    fname = os.path.split(url)[-1]
    fullname = os.path.join(path, fname)
    retry_cnt = 0

    while not os.path.exists(fullname):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RetryError(url, DOWNLOAD_RETRY_LIMIT)

        logger.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        if req.status_code != 200:
            raise UrlError(url, req.status_code)

    # To protect against interrupted downloads, download to
    # tmp_fullname first and move tmp_fullname to fullname
    # after the download finishes
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                for chunk in tqdm.tqdm(
                        req.iter_content(chunk_size=1024),
                        total=(int(total_size) + 1023) // 1024,
                        unit='KB'):
                    f.write(chunk)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname
Example #29
def create_strategy(config):
    """
    Create build strategy and exec strategy.

    Args:
        config(dict): config

    Returns:
        build_strategy: build strategy
        exec_strategy: exec strategy
    """
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = (
        10000
        if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)

    fuse_op = 'AMP' in config

    fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
    fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
    fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
    enable_addto = config.get('enable_addto', fuse_op)

    try:
        build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
    except Exception as e:
        logger.info(
            "PaddlePaddle version 1.7.0 or higher is "
            "required when you want to fuse batch_norm and activation_op.")

    try:
        build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
    except Exception as e:
        logger.info(
            "PaddlePaddle version 1.7.0 or higher is "
            "required when you want to fuse elewise_add_act and activation_op."
        )

    try:
        build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
    except Exception as e:
        logger.info(
            "PaddlePaddle 2.0-rc or higher is "
            "required when you want to enable fuse_bn_add_act_ops strategy.")

    try:
        build_strategy.enable_addto = enable_addto
    except Exception as e:
        logger.info("PaddlePaddle 2.0-rc or higher is "
                    "required when you want to enable addto strategy.")
    return build_strategy, exec_strategy
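
A minimal usage sketch for create_strategy; a plain dict works as config here because the attribute-style access config.AMP is only reached when an 'AMP' key exists.

# A plain dict suffices as long as it has no 'AMP' key.
config = {"fuse_bn_act_ops": True, "fuse_elewise_add_act_ops": True}
build_strategy, exec_strategy = create_strategy(config)
print(exec_strategy.num_threads)   # -> 1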
Example #30
 def __init__(self,
              dataset,
              batch_size,
              sample_per_id,
              shuffle=True,
              drop_last=True,
              sample_method="sample_avg_prob"):
     super().__init__(dataset,
                      batch_size,
                      shuffle=shuffle,
                      drop_last=drop_last)
     assert batch_size % sample_per_id == 0, \
         "PKSampler config error: sample_per_id must be a divisor of batch_size."
     assert hasattr(self.dataset,
                    "labels"), "Dataset must have labels attribute."
     self.sample_per_label = sample_per_id
     self.label_dict = defaultdict(list)
     self.sample_method = sample_method
     for idx, label in enumerate(self.dataset.labels):
         self.label_dict[label].append(idx)
     self.label_list = list(self.label_dict)
     assert len(self.label_list) * self.sample_per_label > self.batch_size, \
         "batch size should be smaller than number of labels * sample_per_id"
     if self.sample_method == "id_avg_prob":
         self.prob_list = np.array([1 / len(self.label_list)] *
                                   len(self.label_list))
     elif self.sample_method == "sample_avg_prob":
         counter = []
         for label_i in self.label_list:
             counter.append(len(self.label_dict[label_i]))
         self.prob_list = np.array(counter) / sum(counter)
     else:
         logger.error(
             "PKSampler only supports id_avg_prob and sample_avg_prob sample methods, "
             "but received {}.".format(self.sample_method))
     diff = np.abs(sum(self.prob_list) - 1)
     if diff > 0.00000001:
         self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
         if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
             logger.error("PKSampler prob list error")
         else:
             logger.info(
                 "PKSampler: sum of prob list is not equal to 1, diff is {}; adjusting the last prob"
                 .format(diff))
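
A standalone sketch of the sample_avg_prob weighting computed in this constructor (toy labels, no sampler class): each label's draw probability is proportional to its sample count.

from collections import defaultdict
import numpy as np

labels = [0, 0, 0, 1, 2, 2]             # toy dataset labels
label_dict = defaultdict(list)
for idx, label in enumerate(labels):
    label_dict[label].append(idx)

label_list = list(label_dict)
counter = [len(label_dict[l]) for l in label_list]
prob_list = np.array(counter) / sum(counter)
print(dict(zip(label_list, prob_list)))  # -> {0: 0.5, 1: 0.1667, 2: 0.3333}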