Exemplo n.º 1
0
def log_info(trainer, batch_size, epoch_id, iter_id):
    """Log one training-iteration summary and record its scalars.

    Builds a single message holding the learning rate, the running average
    of every meter in ``trainer.output_info`` and ``trainer.time_info``, the
    throughput and an ETA, and sends it to ``logger.info``. Afterwards the
    learning rate and each output meter's average are written through
    ``logger.scaler`` at ``trainer.global_step``.

    Args:
        trainer: object exposing ``lr_sch``, ``output_info``, ``time_info``,
            ``config``, ``train_dataloader``, ``global_step`` and
            ``vdl_writer``.
        batch_size: sample count of the current batch; divided by the
            average batch cost to obtain images/sec.
        epoch_id: current epoch; the ETA counts
            ``epochs - epoch_id + 1`` epochs as still to run.
        iter_id: iteration index inside the current epoch.
    """

    def _format_meters(meters):
        # One "name: avg" fragment per meter, in the mapping's own order.
        return [
            "{}: {:.5f}".format(name, meter.avg)
            for name, meter in meters.items()
        ]

    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
    metric_msg = ", ".join(_format_meters(trainer.output_info))
    time_msg = "s, ".join(_format_meters(trainer.time_info))

    avg_batch_cost = trainer.time_info["batch_cost"].avg
    ips_msg = "ips: {:.5f} images/sec".format(batch_size / avg_batch_cost)

    total_epochs = trainer.config["Global"]["epochs"]
    iters_per_epoch = len(trainer.train_dataloader)
    remaining_iters = (total_epochs - epoch_id + 1) * iters_per_epoch - iter_id
    eta_sec = remaining_iters * avg_batch_cost
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))

    header = "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}"
    logger.info(
        header.format(epoch_id, total_epochs, iter_id, iters_per_epoch,
                      lr_msg, metric_msg, time_msg, ips_msg, eta_msg))

    # Scalar records: the learning rate first, then one entry per metric.
    logger.scaler(
        name="lr",
        value=trainer.lr_sch.get_lr(),
        step=trainer.global_step,
        writer=trainer.vdl_writer)
    for name, meter in trainer.output_info.items():
        logger.scaler(
            name="train_{}".format(name),
            value=meter.avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)
Exemplo n.º 2
0
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
    """
    Run one pass over the data, updating and logging every fetched metric.

    Args:
        dataloader: callable; ``dataloader()`` yields the batches to feed
        exe: executor whose ``run`` evaluates ``program`` on each batch
        program: program handed to ``exe.run``
        fetchs(dict): name -> (fetch variable, meter) pairs
        epoch(int): epoch of training or validation, used for logging
        mode(str): compared against 'eval' (log layout) and 'valid'
            (return value); any other value uses the training log layout
        vdl_writer: optional writer; when truthy, the first fetched value
            is recorded as 'loss' and the global ``total_step`` advances

    Returns:
        The ``avg`` of the "top1" meter when mode is 'valid', else None.
    """
    global total_step

    pairs = list(fetchs.values())
    fetch_list = [pair[0] for pair in pairs]
    metric_list = [pair[1] for pair in pairs]
    for meter in metric_list:
        meter.reset()

    batch_time = AverageMeter('elapse', '.3f')
    tic = time.time()
    for step, batch in enumerate(dataloader()):
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        tic = time.time()

        for pos, fetched in enumerate(metrics):
            metric_list[pos].update(fetched[0], len(batch[0]))

        pieces = [str(meter.value) + ' ' for meter in metric_list]
        pieces.append(batch_time.value)
        fetchs_str = ''.join(pieces) + 's'

        if vdl_writer:
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1

        if mode == 'eval':
            logger.info("{:s} step:{:<4d} {:s}s".format(mode, step, fetchs_str))
            continue

        epoch_str = "epoch:{:<3d}".format(epoch)
        step_str = "{:s} step:{:<4d}".format(mode, step)
        # Only the first step of the pass gets a highlighted epoch tag.
        if step == 0:
            epoch_part = logger.coloring(epoch_str, "HEADER")
        else:
            epoch_part = epoch_str
        step_part = logger.coloring(step_str, "PURPLE")
        metrics_part = logger.coloring(fetchs_str, 'OKGREEN')
        logger.info("{:s} {:s} {:s}".format(epoch_part, step_part,
                                            metrics_part))

    end_pieces = [str(meter.mean) + ' ' for meter in metric_list]
    end_pieces.append(batch_time.total)
    end_str = ''.join(end_pieces) + 's'

    if mode == 'eval':
        logger.info("END {:s} {:s}s".format(mode, end_str))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        epoch_part = logger.coloring(end_epoch_str, "RED")
        mode_part = logger.coloring(mode, "PURPLE")
        summary_part = logger.coloring(end_str, "OKGREEN")
        logger.info("{:s} {:s} {:s}".format(epoch_part, mode_part,
                                            summary_part))

    # return top1_acc in order to save the best model
    return fetchs["top1"][1].avg if mode == 'valid' else None
Exemplo n.º 3
0
def run(dataloader,
        config,
        net,
        optimizer=None,
        lr_scheduler=None,
        epoch=0,
        mode='train',
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader: callable; ``dataloader()`` yields the batches.
            ``len(dataloader)`` is also used for the ETA in train mode
        config: configuration read through ``config.get``, ``config.topk``
            and ``config["epochs"]``
        net: model forwarded to ``create_fetchs``
        optimizer: stepped and cleared after every batch in train mode
        lr_scheduler: optional scheduler stepped in train mode
        epoch(int): epoch of training or validation
        mode(str): 'train' enables backward/optimizer steps; 'eval' selects
            the END log layout; 'valid' makes the function return a metric
        vdl_writer: optional writer; in train mode the learning rate and
            every fetch are recorded at the global ``total_step``

    Returns:
        When mode is 'valid': the ``avg`` of the 'multilabel_accuracy' meter
        if ``multilabel`` is set, otherwise of the 'top1' meter. None for
        every other mode.
    """
    print_interval = config.get("print_interval", 10)
    use_mix = config.get("use_mix", False) and mode == "train"
    multilabel = config.get("multilabel", False)
    classes_num = config.get("classes_num")

    metric_list = [
        ("loss", AverageMeter(
            'loss', '7.5f', postfix=",")),
        ("lr", AverageMeter(
            'lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter(
            'batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter(
            'reader_cost', '.5f', postfix=" s,")),
    ]
    # Accuracy meters are only tracked when mix augmentation is off; they
    # are inserted at the front so they are printed first.
    if not use_mix:
        if not multilabel:
            topk_name = 'top{}'.format(config.topk)
            metric_list.insert(
                0, (topk_name, AverageMeter(
                    topk_name, '.5f', postfix=",")))
            metric_list.insert(
                0, ("top1", AverageMeter(
                    "top1", '.5f', postfix=",")))
        else:
            metric_list.insert(
                0, ("multilabel_accuracy", AverageMeter(
                    "multilabel_accuracy", '.5f', postfix=",")))
            metric_list.insert(
                0, ("hamming_distance", AverageMeter(
                    "hamming_distance", '.5f', postfix=",")))

    metric_list = OrderedDict(metric_list)

    # Stays 0 when the dataloader yields nothing, so the END summary below
    # can still be produced instead of failing on an unbound name.
    batch_size = 0
    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        # avoid statistics from warmup time
        if idx == 10:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()

        metric_list['reader_time'].update(time.time() - tic)
        batch_size = len(batch[0])
        feeds = create_feeds(batch, use_mix, classes_num, multilabel)
        fetchs = create_fetchs(feeds, net, config, mode)
        if mode == 'train':
            avg_loss = fetchs['loss']
            avg_loss.backward()

            optimizer.step()
            optimizer.clear_grad()
            lr_value = optimizer._global_learning_rate().numpy()[0]
            metric_list['lr'].update(lr_value, batch_size)

            if lr_scheduler is not None:
                if lr_scheduler.update_specified:
                    # Step only every `update_step_interval` iterations once
                    # `update_start_step` has been reached.
                    curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                    update = max(
                        0, curr_global_counter - lr_scheduler.update_start_step
                    ) % lr_scheduler.update_step_interval == 0
                    if update:
                        lr_scheduler.step()
                else:
                    lr_scheduler.step()

        # Convert every fetch once; the values feed both the meters and the
        # scalar writer below.
        fetch_values = [(name, fetch.numpy()[0])
                        for name, fetch in fetchs.items()]
        for name, value in fetch_values:
            metric_list[name].update(value, batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        tic = time.time()

        if vdl_writer and mode == "train":
            global total_step
            logger.scaler(
                name="lr", value=lr_value, step=total_step, writer=vdl_writer)
            for name, value in fetch_values:
                logger.scaler(
                    name="train_{}".format(name),
                    value=value,
                    step=total_step,
                    writer=vdl_writer)
            total_step += 1

        if idx % print_interval == 0:
            # Timing meters print their running mean, the others their
            # latest value.
            fetchs_str = ' '.join([
                str(metric_list[key].mean)
                if "time" in key else str(metric_list[key].value)
                for key in metric_list
            ])
            ips_info = "ips: {:.5f} images/sec".format(
                batch_size / metric_list["batch_time"].avg)

            if mode == "train":
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                eta_sec = ((config["epochs"] - epoch) * len(dataloader) - idx
                           ) * metric_list["batch_time"].avg
                eta_str = "eta: {:s}".format(
                    str(datetime.timedelta(seconds=int(eta_sec))))
                logger.info("{:s}, {:s}, {:s} {:s}, {:s}".format(
                    epoch_str, step_str, fetchs_str, ips_info, eta_str))
            else:
                logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
                    mode, idx, fetchs_str, ips_info))

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list['batch_time'].total])
    batch_cost = metric_list["batch_time"]
    # With no batches (or a zero timer sum) report 0 images/sec rather than
    # dividing by zero.
    if batch_size and batch_cost.sum > 0:
        ips = batch_size * batch_cost.count / batch_cost.sum
    else:
        ips = 0.0
    ips_info = "ips: {:.5f} images/sec.".format(ips)

    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        if multilabel:
            return metric_list['multilabel_accuracy'].avg
        else:
            return metric_list['top1'].avg
Exemplo n.º 4
0
def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader): iterated directly when
            ``use_dali`` is set, otherwise called first to obtain the
            iterator
        exe: executor whose ``run`` evaluates ``program``
        program: program handed to ``exe.run``
        feeds(dict): feed variables; their ``name`` keys the feed dict in
            the non-DALI path
        fetchs(dict): name -> (fetch variable, meter) for the measures and
            the loss; must contain "loss"
        epoch(int): epoch of training or validation
        mode(str): 'train' or 'valid'; 'valid' also returns top1
        config: configuration; required in practice despite the None
            default, because ``config.topk`` and ``config.get`` are used
            unconditionally
        vdl_writer: optional writer for the per-step 'loss' scalar
        lr_scheduler: optional scheduler; stepped every iteration or, when
            ``update_specified`` is set, on its configured interval

    Returns:
        The ``avg`` of the "top1" meter when mode is 'valid', else None.
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [
        ("lr", AverageMeter('lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter('batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter('reader_cost', '.5f', postfix=" s,")),
    ]
    topk_name = 'top{}'.format(config.topk)
    metric_list.insert(0, ("loss", fetchs["loss"][1]))
    use_mix = config.get("use_mix", False) and mode == "train"
    if not use_mix:
        metric_list.insert(0, (topk_name, fetchs[topk_name][1]))
        metric_list.insert(0, ("top1", fetchs["top1"][1]))

    metric_list = OrderedDict(metric_list)

    for m in metric_list.values():
        m.reset()

    use_dali = config.get('use_dali', False)
    dataloader = dataloader if use_dali else dataloader()
    tic = time.time()

    idx = 0
    batch_size = None
    while True:
        # The DALI maybe raise RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()

        metric_list['reader_time'].update(time.time() - tic)

        if use_dali:
            batch_size = batch[0]["feed_image"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            # Pair the i-th feed variable with the i-th element of the batch.
            feed_dict = {
                key.name: batch[i]
                for i, key in enumerate(feeds.values())
            }
        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_list[name].update(np.mean(m), batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        # The scheduler is optional (see the None check below), so only
        # record the learning rate when one was supplied.
        if mode == "train" and lr_scheduler is not None:
            metric_list['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_list[key].mean)
            if "time" in key else str(metric_list[key].value)
            for key in metric_list
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_list["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            if lr_scheduler.update_specified:
                curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                update = max(
                    0, curr_global_counter - lr_scheduler.update_start_step
                ) % lr_scheduler.update_step_interval == 0
                if update:
                    lr_scheduler.step()
            else:
                lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'valid':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(
                    mode, idx, fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            if idx % config.get('print_interval', 10) == 0:
                # NOTE: idx is incremented before this point, so it is
                # never 0 here and the HEADER colouring is never applied.
                logger.info("{:s} {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN')))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list["batch_time"].total])
    batch_cost = metric_list["batch_time"]
    # With no batches (batch_size still None) or a zero timer sum, report
    # 0 images/sec instead of raising TypeError / ZeroDivisionError.
    if batch_size and batch_cost.sum > 0:
        ips = batch_size * batch_cost.count / batch_cost.sum
    else:
        ips = 0.0
    ips_info = "ips: {:.5f} images/sec.".format(ips)
    if mode == 'valid':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg
Exemplo n.º 5
0
def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None,
        profiler_options=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader): iterated directly when
            ``use_dali`` is set, otherwise called first to obtain the
            iterator
        exe: executor whose ``run`` evaluates ``program``
        program: program handed to ``exe.run``
        feeds(dict): feed variables; their ``name`` keys the feed dict in
            the non-DALI path
        fetchs(dict): name -> (fetch variable, meter) for the measures and
            the loss
        epoch(int): epoch of training or evaluation
        mode(str): 'train' or 'eval'; 'eval' also returns top1
        config: configuration; required in practice despite the None
            default, because ``config["Global"]`` is read unconditionally
        vdl_writer: optional writer for the per-step 'loss' scalar
        lr_scheduler: optional scheduler, stepped once per iteration
        profiler_options: forwarded to ``profiler.add_profiler_step`` on
            every iteration

    Returns:
        The ``avg`` of the "top1" meter when mode is 'eval', else None.
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_dict = OrderedDict([("lr",
                                AverageMeter('lr',
                                             'f',
                                             postfix=",",
                                             need_avg=False))])

    for k in fetchs:
        metric_dict[k] = fetchs[k][1]

    metric_dict["batch_time"] = AverageMeter('batch_cost',
                                             '.5f',
                                             postfix=" s,")
    metric_dict["reader_time"] = AverageMeter('reader_cost',
                                              '.5f',
                                              postfix=" s,")

    for m in metric_dict.values():
        m.reset()

    use_dali = config["Global"].get('use_dali', False)
    tic = time.time()

    if not use_dali:
        dataloader = dataloader()

    idx = 0
    batch_size = None
    while True:
        # The DALI maybe raise RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_dict["batch_time"].reset()
            metric_dict["reader_time"].reset()

        metric_dict['reader_time'].update(time.time() - tic)

        profiler.add_profiler_step(profiler_options)

        if use_dali:
            batch_size = batch[0]["data"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            # Pair the i-th feed variable with the i-th element of the batch.
            feed_dict = {
                key.name: batch[i]
                for i, key in enumerate(feeds.values())
            }

        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_dict[name].update(np.mean(m), batch_size)
        metric_dict["batch_time"].update(time.time() - tic)
        # The scheduler is optional (see the None check below), so only
        # record the learning rate when one was supplied.
        if mode == "train" and lr_scheduler is not None:
            metric_dict['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_dict[key].mean)
            if "time" in key else str(metric_dict[key].value)
            for key in metric_dict
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_dict["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'eval':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(
                    mode, idx, fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_dict.values()] +
                       [metric_dict["batch_time"].total])
    avg_batch_cost = metric_dict["batch_time"].avg
    # With no batches (batch_size still None) or a zero average cost,
    # report 0 images/sec instead of raising TypeError / ZeroDivisionError.
    if batch_size and avg_batch_cost > 0:
        ips = batch_size / avg_batch_cost
    else:
        ips = 0.0
    ips_info = "ips: {:.5f} images/sec.".format(ips)
    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'eval':
        return fetchs["top1"][1].avg
Exemplo n.º 6
0
def run(dataloader,
        exe,
        program,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader): iterated directly when
            ``config.get('use_dali')`` is truthy
        exe: executor whose ``run`` evaluates ``program``
        program: program handed to ``exe.run``
        fetchs(dict): name -> (fetch variable, meter) for the measures and
            the loss
        epoch(int): epoch of training or validation
        mode(str): 'eval' selects the evaluation log layout; 'valid' makes
            the function return top1
        config: configuration; required in practice despite the None
            default, because ``config.get`` is called unconditionally
        vdl_writer: optional writer for the per-step 'loss' scalar

    Returns:
        The ``avg`` of the "top1" meter when mode is 'valid', else None.
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.5f', need_avg=True)
    tic = time.time()
    # NOTE(review): the non-DALI path calls the loader twice
    # (`dataloader()()`); kept as in the original - confirm this is intended.
    dataloader = dataloader if config.get('use_dali') else dataloader()()
    # Stays 0 when the dataloader yields nothing, so the END summary below
    # can still be produced instead of failing on an unbound name.
    batch_size = 0
    for idx, batch in enumerate(dataloader):
        # Drop the statistics gathered during the first 10 (warmup) steps.
        if idx == 10:
            for m in metric_list:
                m.reset()
            batch_time.reset()
        batch_size = batch[0]["feed_image"].shape()[0]
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        for i, m in enumerate(metrics):
            metric_list[i].update(np.mean(m), batch_size)
        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.mean]) + 's'
        ips_info = " ips: {:.5f} images/sec.".format(batch_size /
                                                     batch_time.avg)
        fetchs_str += ips_info
        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'eval':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(
                    mode, idx, fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))
        tic = time.time()

    if config.get('use_dali'):
        dataloader.reset()

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    # With no batches (or a zero timer sum) report 0 images/sec rather than
    # dividing by zero.
    if batch_size and batch_time.sum > 0:
        ips = batch_size * batch_time.count / batch_time.sum
    else:
        ips = 0.0
    ips_info = "ips: {:.5f} images/sec.".format(ips)

    if mode == 'eval':
        logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg
Exemplo n.º 7
0
def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader): iterated directly when
            ``use_dali`` is set, otherwise called first to obtain the
            iterable
        exe: executor whose ``run`` evaluates ``program``
        program: program handed to ``exe.run``
        feeds(dict): feed variables; their ``name`` keys the feed dict in
            the non-DALI path
        fetchs(dict): name -> (fetch variable, meter) for the measures and
            the loss
        epoch(int): epoch of training or validation
        mode(str): 'train' or 'valid'; 'valid' also returns top1
        config: configuration; required in practice despite the None
            default, because ``config.get`` is called unconditionally
        vdl_writer: optional writer for the per-step 'loss' scalar
        lr_scheduler: optional scheduler; stepped every iteration or, when
            ``update_specified`` is set, on its configured interval

    Returns:
        The ``avg`` of the "top1" meter when mode is 'valid', else None.
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    # In train mode an extra meter for the learning rate is kept last.
    if mode == "train":
        metric_list.append(AverageMeter('lr', 'f', need_avg=False))
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.3f')
    use_dali = config.get('use_dali', False)
    dataloader = dataloader if use_dali else dataloader()
    # Stays 0 when the dataloader yields nothing, so the END summary below
    # can still be produced instead of failing on an unbound name.
    batch_size = 0
    tic = time.time()
    for idx, batch in enumerate(dataloader):
        # ignore the warmup iters
        if idx == 5:
            batch_time.reset()
        if use_dali:
            batch_size = batch[0]["feed_image"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            # Pair the i-th feed variable with the i-th element of the batch.
            feed_dict = {
                key.name: batch[i]
                for i, key in enumerate(feeds.values())
            }
        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        for i, m in enumerate(metrics):
            metric_list[i].update(np.mean(m), batch_size)

        # The scheduler is optional (see the None check below), so only
        # record the learning rate when one was supplied.
        if mode == "train" and lr_scheduler is not None:
            metric_list[-1].update(lr_scheduler.get_lr())

        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.mean]) + 's'
        ips_info = " ips: {:.5f} images/sec.".format(batch_size /
                                                     batch_time.avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            if lr_scheduler.update_specified:
                curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                update = max(
                    0, curr_global_counter - lr_scheduler.update_start_step
                ) % lr_scheduler.update_step_interval == 0
                if update:
                    lr_scheduler.step()
            else:
                lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'valid':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(
                    mode, idx, fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN')))

        tic = time.time()

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    # With no batches (or a zero timer sum) report 0 images/sec rather than
    # dividing by zero.
    if batch_size and batch_time.sum > 0:
        ips = batch_size * batch_time.count / batch_time.sum
    else:
        ips = 0.0
    ips_info = "ips: {:.5f} images/sec.".format(ips)
    if mode == 'valid':
        logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg
Exemplo n.º 8
0
    def train(self):
        """Run the full training loop, with optional evaluation and checkpointing.

        Requires ``self.mode == "train"``. Resets the per-run bookkeeping
        (``output_info``, ``time_info``, ``global_step``), optionally restores
        state through ``init_model`` when ``Global.checkpoints`` is set, then
        for every remaining epoch:

        1. calls ``self.train_epoch_func(self, epoch_id, print_batch_step)``;
        2. resets the DALI dataloader when ``self.use_dali`` is set;
        3. logs the average of every meter in ``output_info`` and clears it;
        4. if ``eval_during_train`` is enabled and the epoch is a multiple of
           ``eval_interval``, evaluates, saves a "best_model" checkpoint on
           improvement and records the ``eval_acc`` scalar;
        5. saves an ``epoch_<n>`` checkpoint every ``save_interval`` epochs
           and always overwrites the ``latest`` checkpoint.

        The VisualDL writer is closed at the end when one is present.
        """
        assert self.mode == "train"
        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]
        # Best evaluation result so far and the epoch that produced it; may
        # be overwritten below with values restored from a checkpoint.
        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # Metric meters keyed by name. Created empty here and read (via
        # `.avg`) after each epoch - presumably populated by
        # `train_epoch_func`; verify against that function.
        self.output_info = dict()
        # Timing meters for the whole batch and for data reading.
        self.time_info = {
            "batch_cost": AverageMeter("batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter("reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        self.global_step = 0

        # Resume: merge whatever metric info `init_model` returns into
        # `best_metric`, so the epoch loop below starts after the restored
        # epoch.
        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     self.optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        # On Windows one iteration fewer than the dataloader length is used.
        self.max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)
        # Epochs are 1-based and continue after the restored epoch (if any).
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            # Stays 0.0 for epochs in which no evaluation is run; this value
            # is what the periodic/latest checkpoints store as "metric".
            acc = 0.0
            # for one epoch train
            self.train_epoch_func(self, epoch_id, print_batch_step)

            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, self.output_info[key].avg)
                for key in self.output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            # Start the next epoch with fresh metric meters.
            self.output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                # Only a strictly better result replaces the best checkpoint.
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        self.optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(name="eval_acc",
                              value=acc,
                              step=epoch_id,
                              writer=self.vdl_writer)

                # Switch the model back to training mode after evaluation.
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(self.model,
                                     self.optimizer, {
                                         "metric": acc,
                                         "epoch": epoch_id
                                     },
                                     self.output_dir,
                                     model_name=self.config["Arch"]["name"],
                                     prefix="epoch_{}".format(epoch_id))
            # save the latest model
            save_load.save_model(self.model,
                                 self.optimizer, {
                                     "metric": acc,
                                     "epoch": epoch_id
                                 },
                                 self.output_dir,
                                 model_name=self.config["Arch"]["name"],
                                 prefix="latest")

        if self.vdl_writer is not None:
            self.vdl_writer.close()
Exemplo n.º 9
0
def _run_quant_pass(exe, quant_prog, dataloader, fetchs, log_prefix):
    """Feed every batch of ``dataloader`` through ``quant_prog`` once.

    Args:
        exe: executor used to run the program.
        quant_prog: program returned by ``slim.quant.quant_aware``.
        dataloader: callable that yields feed batches.
        fetchs(dict): each value is a ``(variable, meter)`` pair; the variable
            is fetched and its mean is pushed into the paired meter.
        log_prefix(str): text prepended to the metric line that is logged
            every 10 iterations.
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for idx, batch in enumerate(dataloader()):
        results = exe.run(program=quant_prog,
                          feed=batch,
                          fetch_list=fetch_list)
        # exe.run returns one result per entry of fetch_list, so results and
        # meters line up positionally.
        for meter, result in zip(metric_list, results):
            meter.update(np.mean(result), len(batch[0]))

        # Only build the message on the iterations that actually log it.
        if idx % 10 == 0:
            fetchs_str = ''.join(str(m.value) + ' ' for m in metric_list)
            logger.info(log_prefix + fetchs_str)


def main(args):
    """Train a model and optionally run quantization-aware training.

    Steps: load the config, build the train (and optional valid) programs,
    run the regular training loop with periodic validation and checkpointing,
    then, when ``args.use_quant`` is set, run extra quantization-aware
    epochs, evaluate the quantized program and export an inference model.

    Args:
        args: parsed command-line arguments. The attributes read here are
            ``config``, ``override``, ``use_quant``, ``vdl_dir`` and
            ``output_path``.

    Raises:
        SystemExit: if quantization is requested without validation enabled,
            or with a GPU count other than one.
    """
    config = get_config(args.config, overrides=args.override, show=True)
    # Quantization-aware training requires validation to be enabled.
    if not config.validate and args.use_quant:
        logger.error("=====>Train quant model must use validate!")
        sys.exit(1)

    # Number of extra epochs dedicated to quantization-aware training.
    quant_epochs = 5
    # Remember the regular epoch count BEFORE it is extended below. The
    # original loop used `config.epochs - 5` unconditionally, which dropped
    # the last 5 epochs whenever quantization was disabled.
    float_epochs = config.epochs
    if args.use_quant:
        config.epochs = config.epochs + quant_epochs
        gpu_count = get_gpu_count()
        if gpu_count != 1:
            logger.error(
                "=====>`Train quant model must use only one GPU. "
                "Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` ."
            )
            sys.exit(1)

    # Select the devices to run on.
    use_gpu = config.get("use_gpu", True)
    places = fluid.cuda_places() if use_gpu else fluid.cpu_places()

    startup_prog = fluid.Program()
    train_prog = fluid.Program()

    best_top1_acc = 0.0

    # Build the training dataloader and model outputs; with EMA enabled the
    # builder returns one extra value (the EMA object).
    if not config.get('use_ema'):
        train_dataloader, train_fetchs, out, softmax_out = program.build(
            config,
            train_prog,
            startup_prog,
            is_train=True,
            is_distributed=False)
    else:
        train_dataloader, train_fetchs, ema, out, softmax_out = program.build(
            config,
            train_prog,
            startup_prog,
            is_train=True,
            is_distributed=False)
    # Build the validation dataloader and model outputs.
    if config.validate:
        valid_prog = fluid.Program()
        valid_dataloader, valid_fetchs, _, _ = program.build(
            config,
            valid_prog,
            startup_prog,
            is_train=False,
            is_distributed=False)
        # Clone for test so computation unrelated to evaluation is dropped.
        valid_prog = valid_prog.clone(for_test=True)

    # Create the executor and initialise parameters.
    exe = fluid.Executor(places[0])
    exe.run(startup_prog)

    # Load weights: either a pretrained model or a checkpoint.
    init_model(config, train_prog, exe)

    train_reader = Reader(config, 'train')()
    train_dataloader.set_sample_list_generator(train_reader, places)

    compiled_train_prog = program.compile(config, train_prog,
                                          train_fetchs['loss'][0].name)

    if config.validate:
        valid_reader = Reader(config, 'valid')()
        valid_dataloader.set_sample_list_generator(valid_reader, places)
        compiled_valid_prog = program.compile(config,
                                              valid_prog,
                                              share_prog=compiled_train_prog)

    vdl_writer = LogWriter(args.vdl_dir)

    # The checkpoint directory does not change between epochs.
    model_path = os.path.join(config.model_save_dir,
                              config.ARCHITECTURE["name"])

    for epoch_id in range(float_epochs):
        # Train for one epoch.
        program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
                    epoch_id, 'train', config, vdl_writer)

        # Run one evaluation.
        if config.validate and epoch_id % config.valid_interval == 0:
            if config.get('use_ema'):
                logger.info(logger.coloring("EMA validate start..."))
                with ema.apply(exe):
                    _ = program.run(valid_dataloader, exe, compiled_valid_prog,
                                    valid_fetchs, epoch_id, 'valid', config)
                logger.info(logger.coloring("EMA validate over!"))

            top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog,
                                   valid_fetchs, epoch_id, 'valid', config)

            if vdl_writer:
                logger.scaler('valid_avg', top1_acc, epoch_id, vdl_writer)

            if top1_acc > best_top1_acc:
                best_top1_acc = top1_acc
                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
                    best_top1_acc, epoch_id)
                logger.info("{:s}".format(logger.coloring(message, "RED")))
                # The best model is only written on save-interval epochs.
                if epoch_id % config.save_interval == 0:
                    save_model(train_prog, model_path, "best_model")

        # Save a periodic checkpoint.
        if epoch_id % config.save_interval == 0:
            # Drop the checkpoint written three epochs ago, if present.
            stale_dir = os.path.join(model_path, str(epoch_id - 3))
            if epoch_id >= 3 and os.path.exists(stale_dir):
                shutil.rmtree(stale_dir, ignore_errors=True)
            save_model(train_prog, model_path, epoch_id)

    # Nothing below writes to the log writer, so release it here.
    vdl_writer.close()

    # Quantization-aware training.
    if args.use_quant and config.validate:
        # Program used for quantization-aware training.
        quant_program = slim.quant.quant_aware(train_prog,
                                               exe.place,
                                               for_test=False)
        # Program used to evaluate the quantized result.
        val_quant_program = slim.quant.quant_aware(valid_prog,
                                                   exe.place,
                                                   for_test=True)

        for _ in range(quant_epochs):
            _run_quant_pass(exe, quant_program, train_dataloader,
                            train_fetchs, "quant train : ")

        _run_quant_pass(exe, val_quant_program, valid_dataloader,
                        valid_fetchs, "quant valid: ")

        # Convert and save the quantization-trained model.
        # NOTE(review): softmax_out comes from the training build while
        # float_prog derives from the validation program -- confirm the
        # variable resolves in float_prog.
        float_prog, int8_prog = slim.quant.convert(val_quant_program,
                                                   exe.place,
                                                   save_int8=True)
        fluid.io.save_inference_model(dirname=args.output_path,
                                      feeded_var_names=['feed_image'],
                                      target_vars=[softmax_out],
                                      executor=exe,
                                      main_program=float_prog,
                                      model_filename='__model__',
                                      params_filename='__params__')