def train_dali(cfg, weights=None, parallel=True):
    """Train a model whose batches come from the TSN DALI reader.

    Args:
        cfg (dict): configuration. Reads ``DALI_LOADER``, ``MODEL``,
            ``OPTIMIZER``, ``model_name`` and ``epochs``, plus the optional
            ``output_dir``, ``resume_epoch``, ``log_interval``,
            ``save_interval`` and ``PRECISEBN`` entries.
        weights (str): weights path for finetuning; asserted to be used only
            when ``resume_epoch`` is 0.
        parallel (bool): Whether multi-cards training. Default: True.

    """
    logger = get_logger("paddlevideo")
    batch_size = cfg.DALI_LOADER.get('batch_size', 8)
    # Only the device selection matters here; the returned place is not used.
    paddle.set_device('gpu')
    model_name = cfg.model_name
    output_dir = cfg.get("output_dir", f"./output/{model_name}")
    mkdir(output_dir)

    # 1. Construct model
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    # 2. Construct dali dataloader
    dali_builder = TSN_Dali_loader(cfg.DALI_LOADER)
    train_loader = dali_builder.build_dali_reader()

    # 3. Construct solver.
    lr = build_lr(cfg.OPTIMIZER.learning_rate, None)
    optimizer = build_optimizer(
        cfg.OPTIMIZER, lr, parameter_list=model.parameters())

    # Resume: restore both model and optimizer state of the given epoch.
    resume_epoch = cfg.get("resume_epoch", 0)
    if resume_epoch:
        ckpt_stem = osp.join(output_dir,
                             model_name + f"_epoch_{resume_epoch:05d}")
        params_state = load(ckpt_stem + '.pdparams')
        opt_state = load(ckpt_stem + '.pdopt')
        model.set_state_dict(params_state)
        optimizer.set_state_dict(opt_state)

    # Finetune: start from external weights (mutually exclusive with resume).
    if weights:
        assert resume_epoch == 0, "Conflict occurs when finetuning, please switch resume function off by setting resume_epoch to 0 or not indicating it."
        model.set_state_dict(load(weights))

    # 4. Train Model
    for epoch in range(cfg.epochs):
        if epoch < resume_epoch:
            # Epochs before the resume point are only logged, not trained.
            logger.info(
                f"| epoch: [{epoch+1}] <= resume_epoch: [{ resume_epoch}], continue... "
            )
            continue

        model.train()
        meters = build_record(cfg.MODEL)
        started = time.time()
        for step, batch in enumerate(train_loader):
            batch = get_input_data(batch)
            meters['reader_time'].update(time.time() - started)

            # 4.1 forward
            if not parallel:
                outputs = model.train_step(batch)
            else:
                outputs = model._layers.train_step(batch)
                ## required for DataParallel, will remove in next version
                tracked = list(model._find_varbase(outputs))
                model._reducer.prepare_for_backward(tracked)

            # 4.2 backward
            loss = outputs['loss']
            loss.backward()

            # 4.3 minimize
            optimizer.step()
            optimizer.clear_grad()

            # log record
            current_lr = optimizer._global_learning_rate()
            meters['lr'].update(current_lr, batch_size)
            for metric, metric_value in outputs.items():
                meters[metric].update(metric_value, batch_size)

            meters['batch_time'].update(time.time() - started)
            started = time.time()

            if step % cfg.get("log_interval", 10) == 0:
                throughput = batch_size / meters["batch_time"].val
                ips = "ips: {:.5f} instance/sec.".format(throughput)
                log_batch(meters, step, epoch + 1, cfg.epochs, "train", ips)

            # learning rate iter step
            if cfg.OPTIMIZER.learning_rate.get("iter_step"):
                lr.step()

        # learning rate epoch step
        if not cfg.OPTIMIZER.learning_rate.get("iter_step"):
            lr.step()

        batch_timer = meters["batch_time"]
        ips = "ips: {:.5f} instance/sec.".format(
            batch_size * batch_timer.count / batch_timer.sum)
        log_epoch(meters, epoch + 1, "train", ips)

        # use precise bn to improve acc
        if cfg.get("PRECISEBN"):
            on_interval = epoch % cfg.PRECISEBN.preciseBN_interval == 0
            if on_interval or epoch == cfg.epochs - 1:
                bn_iters = min(cfg.PRECISEBN.num_iters_preciseBN,
                               len(train_loader))
                do_preciseBN(model, train_loader, parallel, bn_iters)

        # 5. Save model and optimizer
        if epoch % cfg.get("save_interval", 1) == 0 or epoch == cfg.epochs - 1:
            save_stem = osp.join(output_dir,
                                 model_name + f"_epoch_{epoch+1:05d}")
            save(optimizer.state_dict(), save_stem + ".pdopt")
            save(model.state_dict(), save_stem + ".pdparams")

    logger.info(f'training {model_name} finished')
def train_model_multigrid(cfg, world_size=1, validate=True):
    """Train model entry for multigrid (long/short cycle) training.

    Args:
        cfg (dict): configuration.
        world_size (int): number of cards taking part in training; any value
            other than 1 wraps the model in ``paddle.DataParallel``.
            Default: 1.
        validate (bool): Whether to do evaluation. Default: True.

    """
    # Init multigrid.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)

    # `multigrid` stays None when both cycles are disabled, so the schedule
    # must not be dereferenced unconditionally.
    # NOTE(review): `schedule` may then be None when handed to is_eval_epoch;
    # confirm that helper accepts it.
    schedule = multigrid.schedule if multigrid is not None else None
    multi_save_epoch = [i[-1] - 1 for i in schedule] if schedule else []

    parallel = world_size != 1
    logger = get_logger("paddlevideo")
    batch_size = cfg.DATASET.get('batch_size', 2)
    places = paddle.set_device('gpu')
    model_name = cfg.model_name
    output_dir = cfg.get("output_dir", f"./output/{model_name}")
    mkdir(output_dir)
    local_rank = dist.ParallelEnv().local_rank
    # PRECISEBN is optional (read with .get), so only dereference it when set.
    precise_bn = cfg.get("PRECISEBN")
    num_iters_precise_bn = (cfg.PRECISEBN.num_iters_preciseBN
                            if precise_bn else None)

    # 1. Construct model
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    # 2. Construct dataloader
    train_loader, valid_loader, precise_bn_loader = \
        construct_loader(cfg,
                         places,
                         validate,
                         precise_bn,
                         num_iters_precise_bn,
                         world_size,
                         )

    # 3. Construct optimizer
    lr = build_lr(cfg.OPTIMIZER.learning_rate, len(train_loader))
    optimizer = build_optimizer(cfg.OPTIMIZER,
                                lr,
                                parameter_list=model.parameters())

    # Resume
    resume_epoch = cfg.get("resume_epoch", 0)
    if resume_epoch:
        filename = osp.join(
            output_dir,
            model_name + str(local_rank) + '_' + f"{resume_epoch:05d}")
        subn_load(model, filename, optimizer)

    # 4. Train Model
    best = 0.
    total_epochs = int(cfg.epochs * cfg.MULTIGRID.epoch_factor)
    for epoch in range(total_epochs):
        if epoch < resume_epoch:
            logger.info(
                f"| epoch: [{epoch+1}] <= resume_epoch: [{ resume_epoch}], continue... "
            )
            continue

        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, changed = multigrid.update_long_cycle(cfg, epoch)
            if changed:
                logger.info("====== Rebuild model/optimizer/loader =====")
                (
                    model,
                    lr,
                    optimizer,
                    train_loader,
                    valid_loader,
                    precise_bn_loader,
                ) = build_trainer(cfg, places, parallel, validate, precise_bn,
                                  num_iters_precise_bn, world_size)

                # load checkpoint after re-build model
                if epoch != 0:
                    # no "- 1" on epoch: checkpoints are saved as epoch + 1
                    filename = osp.join(
                        output_dir,
                        model_name + str(local_rank) + '_' + f"{(epoch):05d}")
                    subn_load(model, filename, optimizer)
                # update lr last epoch, not to use saved params
                lr.last_epoch = epoch
                lr.step(rebuild=True)

        model.train()
        record_list = build_record(cfg.MODEL)
        tic = time.time()
        for i, data in enumerate(train_loader):
            record_list['reader_time'].update(time.time() - tic)
            # 4.1 forward
            if parallel:
                outputs = model._layers.train_step(data)
                ## required for DataParallel, will remove in next version
                model._reducer.prepare_for_backward(
                    list(model._find_varbase(outputs)))
            else:
                outputs = model.train_step(data)
            # 4.2 backward
            avg_loss = outputs['loss']
            avg_loss.backward()
            # 4.3 minimize
            optimizer.step()
            optimizer.clear_grad()

            # log record
            record_list['lr'].update(
                optimizer._global_learning_rate().numpy()[0], batch_size)
            for name, value in outputs.items():
                record_list[name].update(value.numpy()[0], batch_size)
            record_list['batch_time'].update(time.time() - tic)
            tic = time.time()

            if i % cfg.get("log_interval", 10) == 0:
                ips = "ips: {:.5f} instance/sec.".format(
                    batch_size / record_list["batch_time"].val)
                log_batch(record_list, i, epoch + 1, total_epochs, "train", ips)

            # learning rate iter step
            if cfg.OPTIMIZER.learning_rate.get("iter_step"):
                lr.step()

        # learning rate epoch step
        if not cfg.OPTIMIZER.learning_rate.get("iter_step"):
            lr.step()

        ips = "ips: {:.5f} instance/sec.".format(
            batch_size * record_list["batch_time"].count /
            record_list["batch_time"].sum)
        log_epoch(record_list, epoch + 1, "train", ips)

        def evaluate(best):
            """Run one pass over valid_loader; return (best top1, improved?)."""
            model.eval()
            record_list = build_record(cfg.MODEL)
            record_list.pop('lr')
            tic = time.time()
            for i, data in enumerate(valid_loader):
                if parallel:
                    outputs = model._layers.val_step(data)
                else:
                    outputs = model.val_step(data)

                # log_record
                for name, value in outputs.items():
                    record_list[name].update(value.numpy()[0], batch_size)

                record_list['batch_time'].update(time.time() - tic)
                tic = time.time()

                if i % cfg.get("log_interval", 10) == 0:
                    ips = "ips: {:.5f} instance/sec.".format(
                        batch_size / record_list["batch_time"].val)
                    log_batch(record_list, i, epoch + 1, total_epochs, "val",
                              ips)

            ips = "ips: {:.5f} instance/sec.".format(
                batch_size * record_list["batch_time"].count /
                record_list["batch_time"].sum)
            log_epoch(record_list, epoch + 1, "val", ips)

            best_flag = False
            if record_list.get('top1') and record_list['top1'].avg > best:
                best = record_list['top1'].avg
                best_flag = True
            return best, best_flag

        # The same predicate drives precise BN, validation and checkpointing.
        eval_epoch = is_eval_epoch(cfg, epoch, total_epochs, schedule)

        # use precise bn to improve acc (only when PRECISEBN is configured)
        if precise_bn and eval_epoch:
            logger.info(f"do precise BN in {epoch+1} ...")
            do_preciseBN(model, precise_bn_loader, parallel,
                         min(num_iters_precise_bn, len(precise_bn_loader)))

        #  aggregate sub_BN stats
        logger.info("Aggregate sub_BatchNorm stats...")
        aggregate_sub_bn_stats(model)

        # 5. Validation (honours the `validate` flag)
        if validate and eval_epoch:
            logger.info(f"eval in {epoch+1} ...")
            with paddle.fluid.dygraph.no_grad():
                best, save_best_flag = evaluate(best)
            # save best
            if save_best_flag:
                save(optimizer.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdopt"))
                save(model.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdparams"))
                logger.info(
                    f"Already save the best model (top1 acc){int(best * 10000) / 10000}"
                )

        # 6. Save model and optimizer
        if (eval_epoch or epoch % cfg.get("save_interval", 10) == 0
                or epoch in multi_save_epoch):
            logger.info("[Save parameters] ======")
            subn_save(output_dir, model_name + str(local_rank) + '_', epoch + 1,
                      model, optimizer)

    logger.info(f'training {model_name} finished')
# Example #3
# 0
def train_model(cfg, parallel=True, validate=True):
    """Train model entry

    Args:
        cfg (dict): configuration.
        parallel (bool): Whether multi-card training. Default: True.
        validate (bool): Whether to do evaluation. Default: True.

    """
    head_name = cfg.MODEL.head.name
    # Metrics whose values are recorded as-is; every other output is
    # converted with ``.numpy()[0]`` before being recorded.
    raw_metrics = ('hit_at_one', 'perr', 'gap')

    logger = get_logger("paddlevideo")
    # single card batch size
    batch_size = cfg.DATASET.get('batch_size', 8)
    valid_batch_size = cfg.DATASET.get('valid_batch_size', batch_size)
    places = paddle.set_device('gpu')

    # default num worker: 0, which means no subprocess will be created
    num_workers = cfg.DATASET.get('num_workers', 0)
    model_name = cfg.model_name
    output_dir = cfg.get("output_dir", f"./output/{model_name}")
    mkdir(output_dir)

    # 1. Construct model
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    # "shuffle" is popped so it is not forwarded to build_dataset with the
    # rest of the dataset config.
    train_shuffle = True
    valid_shuffle = False
    if "shuffle" in cfg.DATASET.train:
        train_shuffle = cfg.DATASET.train.pop("shuffle")
    if "shuffle" in cfg.DATASET.valid:
        valid_shuffle = cfg.DATASET.valid.pop("shuffle")

    # 2. Construct dataset and dataloader
    train_dataset = build_dataset((cfg.DATASET.train, cfg.PIPELINE.train))
    train_dataloader_setting = dict(batch_size=batch_size,
                                    num_workers=num_workers,
                                    collate_fn_cfg=cfg.get('MIX', None),
                                    places=places,
                                    shuffle=train_shuffle)

    train_loader = build_dataloader(train_dataset, **train_dataloader_setting)
    if validate:
        valid_dataset = build_dataset((cfg.DATASET.valid, cfg.PIPELINE.valid))
        validate_dataloader_setting = dict(batch_size=valid_batch_size,
                                           num_workers=num_workers,
                                           places=places,
                                           drop_last=False,
                                           shuffle=valid_shuffle)
        valid_loader = build_dataloader(valid_dataset,
                                        **validate_dataloader_setting)

    # 3. Construct learning rate schedule and optimizer
    lr = build_lr(cfg.OPTIMIZER.learning_rate, len(train_loader))
    optimizer = build_optimizer(cfg.OPTIMIZER,
                                lr,
                                parameter_list=model.parameters())

    # Resume
    resume_epoch = cfg.get("resume_epoch", 0)
    if resume_epoch:
        filename = osp.join(output_dir,
                            model_name + f"_epoch_{resume_epoch:05d}")
        resume_model_dict = load(filename + '.pdparams')
        resume_opt_dict = load(filename + '.pdopt')
        model.set_state_dict(resume_model_dict)
        optimizer.set_state_dict(resume_opt_dict)

    # 4. Train Model
    best = 0.
    for epoch in range(0, cfg.epochs):
        if epoch < resume_epoch:
            logger.info(
                f"| epoch: [{epoch+1}] <= resume_epoch: [{ resume_epoch}], continue... "
            )
            continue
        model.train()
        record_list = build_record(cfg.MODEL)
        tic = time.time()
        for i, data in enumerate(train_loader):
            record_list['reader_time'].update(time.time() - tic)
            # 4.1 forward
            if parallel:
                outputs = model._layers.train_step(data)
            else:
                outputs = model.train_step(data)
            # 4.2 backward
            avg_loss = outputs['loss']
            avg_loss.backward()
            # 4.3 minimize
            optimizer.step()
            optimizer.clear_grad()

            # log record
            record_list['lr'].update(
                optimizer._global_learning_rate().numpy()[0], batch_size)
            for name, value in outputs.items():
                # Compare by equality: ``is`` on string literals relies on
                # interning and is not guaranteed to match.
                if name in raw_metrics:
                    record_list[name].update(value, batch_size)
                else:
                    record_list[name].update(value.numpy()[0], batch_size)

            record_list['batch_time'].update(time.time() - tic)
            tic = time.time()

            if i % cfg.get("log_interval", 10) == 0:
                ips = "ips: {:.5f} instance/sec.".format(
                    batch_size / record_list["batch_time"].val)
                log_batch(record_list, i, epoch + 1, cfg.epochs, "train", ips)

            # learning rate iter step
            if cfg.OPTIMIZER.learning_rate.get("iter_step"):
                lr.step()

        # learning rate epoch step
        if not cfg.OPTIMIZER.learning_rate.get("iter_step"):
            lr.step()

        ips = "ips: {:.5f} instance/sec.".format(
            batch_size * record_list["batch_time"].count /
            record_list["batch_time"].sum)
        log_epoch(record_list, epoch + 1, "train", ips)

        def evaluate(best):
            """Run one pass over valid_loader; return (best metric, improved?)."""
            model.eval()
            record_list = build_record(cfg.MODEL)
            record_list.pop('lr')
            tic = time.time()
            for i, data in enumerate(valid_loader):

                if parallel:
                    outputs = model._layers.val_step(data)
                else:
                    outputs = model.val_step(data)

                # log_record; the validation loader is built with
                # valid_batch_size, so that is the per-batch sample count.
                for name, value in outputs.items():
                    if name in raw_metrics:
                        record_list[name].update(value, valid_batch_size)
                    else:
                        record_list[name].update(value.numpy()[0],
                                                 valid_batch_size)

                record_list['batch_time'].update(time.time() - tic)
                tic = time.time()

                if i % cfg.get("log_interval", 10) == 0:
                    ips = "ips: {:.5f} instance/sec.".format(
                        valid_batch_size / record_list["batch_time"].val)
                    log_batch(record_list, i, epoch + 1, cfg.epochs, "val",
                              ips)

            ips = "ips: {:.5f} instance/sec.".format(
                valid_batch_size * record_list["batch_time"].count /
                record_list["batch_time"].sum)
            log_epoch(record_list, epoch + 1, "val", ips)

            # AttentionLstmHead is ranked by hit_at_one, everything else by top1.
            best_flag = False
            if head_name == "AttentionLstmHead":
                if record_list.get(
                        'hit_at_one') and record_list['hit_at_one'].avg > best:
                    best = record_list['hit_at_one'].avg
                    best_flag = True
            else:
                if record_list.get('top1') and record_list['top1'].avg > best:
                    best = record_list['top1'].avg
                    best_flag = True

            return best, best_flag

        # use precise bn to improve acc
        if cfg.get("PRECISEBN") and (epoch % cfg.PRECISEBN.preciseBN_interval
                                     == 0 or epoch == cfg.epochs - 1):
            do_preciseBN(
                model, train_loader, parallel,
                min(cfg.PRECISEBN.num_iters_preciseBN, len(train_loader)))

        # 5. Validation
        # Parenthesised so that `validate=False` also skips the last epoch;
        # valid_loader is only built when validate is true.
        if validate and (epoch % cfg.get("val_interval", 1) == 0
                         or epoch == cfg.epochs - 1):
            with paddle.fluid.dygraph.no_grad():
                best, save_best_flag = evaluate(best)
            # save best
            if save_best_flag:
                save(optimizer.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdopt"))
                save(model.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdparams"))
                if head_name == "AttentionLstmHead":
                    logger.info(
                        f"Already save the best model (hit_at_one){best}")
                else:
                    logger.info(
                        f"Already save the best model (top1 acc){best}")

        # 6. Save model and optimizer
        if epoch % cfg.get("save_interval", 1) == 0 or epoch == cfg.epochs - 1:
            save(
                optimizer.state_dict(),
                osp.join(output_dir,
                         model_name + f"_epoch_{epoch+1:05d}.pdopt"))
            save(
                model.state_dict(),
                osp.join(output_dir,
                         model_name + f"_epoch_{epoch+1:05d}.pdparams"))

    logger.info(f'training {model_name} finished')
# Example #4
# 0
def train_model(cfg,
                weights=None,
                parallel=True,
                validate=True,
                amp=False,
                fleet=False):
    """Train model entry

    Args:
        cfg (dict): configuration.
        weights (str): weights path for finetuning.
        parallel (bool): Whether multi-cards training. Default: True.
        validate (bool): Whether to do evaluation. Default: True.
        amp (bool): Whether to train with automatic mixed precision
            (``paddle.amp``). Default: False.
        fleet (bool): Whether to initialise and use the collective fleet
            API for the model and optimizer. Default: False.

    """
    # The boolean `fleet` parameter shadows any module-level `fleet` name
    # inside this function, so the fleet API is imported under a separate
    # local name instead of calling methods on the flag itself.
    use_fleet = bool(fleet)
    if use_fleet:
        from paddle.distributed import fleet as fleet_api
        fleet_api.init(is_collective=True)

    logger = get_logger("paddlevideo")
    batch_size = cfg.DATASET.get('batch_size', 8)
    valid_batch_size = cfg.DATASET.get('valid_batch_size', batch_size)
    places = paddle.set_device('gpu')

    # default num worker: 0, which means no subprocess will be created
    num_workers = cfg.DATASET.get('num_workers', 0)
    model_name = cfg.model_name
    output_dir = cfg.get("output_dir", f"./output/{model_name}")
    mkdir(output_dir)

    # 1. Construct model
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    if use_fleet:
        # distributed_model is provided by the fleet API.
        model = fleet_api.distributed_model(model)

    # 2. Construct dataset and dataloader
    train_dataset = build_dataset((cfg.DATASET.train, cfg.PIPELINE.train))
    train_dataloader_setting = dict(batch_size=batch_size,
                                    num_workers=num_workers,
                                    collate_fn_cfg=cfg.get('MIX', None),
                                    places=places)

    train_loader = build_dataloader(train_dataset, **train_dataloader_setting)
    if validate:
        valid_dataset = build_dataset((cfg.DATASET.valid, cfg.PIPELINE.valid))
        validate_dataloader_setting = dict(
            batch_size=valid_batch_size,
            num_workers=num_workers,
            places=places,
            drop_last=False,
            shuffle=cfg.DATASET.get(
                'shuffle_valid',
                False)  #NOTE: attention lstm need shuffle valid data.
        )
        valid_loader = build_dataloader(valid_dataset,
                                        **validate_dataloader_setting)

    # 3. Construct solver.
    lr = build_lr(cfg.OPTIMIZER.learning_rate, len(train_loader))
    optimizer = build_optimizer(cfg.OPTIMIZER,
                                lr,
                                parameter_list=model.parameters())
    if use_fleet:
        optimizer = fleet_api.distributed_optimizer(optimizer)

    # Resume
    resume_epoch = cfg.get("resume_epoch", 0)
    if resume_epoch:
        filename = osp.join(output_dir,
                            model_name + f"_epoch_{resume_epoch:05d}")
        resume_model_dict = load(filename + '.pdparams')
        resume_opt_dict = load(filename + '.pdopt')
        model.set_state_dict(resume_model_dict)
        optimizer.set_state_dict(resume_opt_dict)

    # Finetune:
    if weights:
        assert resume_epoch == 0, "Conflict occurs when finetuning, please switch resume function off by setting resume_epoch to 0 or not indicating it."
        model_dict = load(weights)
        model.set_state_dict(model_dict)

    # 4. Train Model
    ###AMP###
    if amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    best = 0.
    for epoch in range(0, cfg.epochs):
        if epoch < resume_epoch:
            logger.info(
                f"| epoch: [{epoch+1}] <= resume_epoch: [{ resume_epoch}], continue... "
            )
            continue
        model.train()
        record_list = build_record(cfg.MODEL)
        tic = time.time()
        for i, data in enumerate(train_loader):
            record_list['reader_time'].update(time.time() - tic)

            # 4.1 forward

            ###AMP###
            if amp:
                # custom_black_list takes a collection of op names; a bare
                # string is not one, so wrap the name in a set.
                with paddle.amp.auto_cast(
                        custom_black_list={"temporal_shift"}):
                    if parallel:
                        outputs = model._layers.train_step(data)
                    else:
                        outputs = model.train_step(data)

                avg_loss = outputs['loss']
                scaled = scaler.scale(avg_loss)
                scaled.backward()
                # keep prior to 2.0 design
                scaler.minimize(optimizer, scaled)
                optimizer.clear_grad()

            else:
                if parallel:
                    outputs = model._layers.train_step(data)
                else:
                    outputs = model.train_step(data)

                # 4.2 backward
                avg_loss = outputs['loss']
                avg_loss.backward()
                # 4.3 minimize
                optimizer.step()
                optimizer.clear_grad()

            # log record
            record_list['lr'].update(optimizer._global_learning_rate(),
                                     batch_size)
            for name, value in outputs.items():
                record_list[name].update(value, batch_size)

            record_list['batch_time'].update(time.time() - tic)
            tic = time.time()

            if i % cfg.get("log_interval", 10) == 0:
                ips = "ips: {:.5f} instance/sec.".format(
                    batch_size / record_list["batch_time"].val)
                log_batch(record_list, i, epoch + 1, cfg.epochs, "train", ips)

            # learning rate iter step
            if cfg.OPTIMIZER.learning_rate.get("iter_step"):
                lr.step()

        # learning rate epoch step
        if not cfg.OPTIMIZER.learning_rate.get("iter_step"):
            lr.step()

        ips = "ips: {:.5f} instance/sec.".format(
            batch_size * record_list["batch_time"].count /
            record_list["batch_time"].sum)
        log_epoch(record_list, epoch + 1, "train", ips)

        def evaluate(best):
            """Run one pass over valid_loader; return (best metric, improved?)."""
            model.eval()
            record_list = build_record(cfg.MODEL)
            record_list.pop('lr')
            tic = time.time()
            for i, data in enumerate(valid_loader):

                if parallel:
                    outputs = model._layers.val_step(data)
                else:
                    outputs = model.val_step(data)

                # log_record; the validation loader is built with
                # valid_batch_size, so that is the per-batch sample count.
                for name, value in outputs.items():
                    record_list[name].update(value, valid_batch_size)

                record_list['batch_time'].update(time.time() - tic)
                tic = time.time()

                if i % cfg.get("log_interval", 10) == 0:
                    ips = "ips: {:.5f} instance/sec.".format(
                        valid_batch_size / record_list["batch_time"].val)
                    log_batch(record_list, i, epoch + 1, cfg.epochs, "val", ips)

            ips = "ips: {:.5f} instance/sec.".format(
                valid_batch_size * record_list["batch_time"].count /
                record_list["batch_time"].sum)
            log_epoch(record_list, epoch + 1, "val", ips)

            # Whichever of these metrics the record list carries is compared
            # against the running best.
            best_flag = False
            for top_flag in ['hit_at_one', 'top1']:
                if record_list.get(
                        top_flag) and record_list[top_flag].avg > best:
                    best = record_list[top_flag].avg
                    best_flag = True

            return best, best_flag

        # use precise bn to improve acc
        if cfg.get("PRECISEBN") and (epoch % cfg.PRECISEBN.preciseBN_interval
                                     == 0 or epoch == cfg.epochs - 1):
            do_preciseBN(
                model, train_loader, parallel,
                min(cfg.PRECISEBN.num_iters_preciseBN, len(train_loader)))

        # 5. Validation
        # Parenthesised so that `validate=False` also skips the last epoch;
        # valid_loader is only built when validate is true.
        if validate and (epoch % cfg.get("val_interval", 1) == 0
                         or epoch == cfg.epochs - 1):
            with paddle.fluid.dygraph.no_grad():
                best, save_best_flag = evaluate(best)
            # save best
            if save_best_flag:
                save(optimizer.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdopt"))
                save(model.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdparams"))
                if model_name == "AttentionLstm":
                    logger.info(
                        f"Already save the best model (hit_at_one){best}")
                else:
                    logger.info(
                        f"Already save the best model (top1 acc){int(best *10000)/10000}"
                    )

        # 6. Save model and optimizer
        if epoch % cfg.get("save_interval", 1) == 0 or epoch == cfg.epochs - 1:
            save(
                optimizer.state_dict(),
                osp.join(output_dir,
                         model_name + f"_epoch_{epoch+1:05d}.pdopt"))
            save(
                model.state_dict(),
                osp.join(output_dir,
                         model_name + f"_epoch_{epoch+1:05d}.pdparams"))

    logger.info(f'training {model_name} finished')
# Example #5
# 0
def train_model(cfg,
                weights=None,
                parallel=True,
                validate=True,
                amp=False,
                use_fleet=False):
    """Train model entry

    Builds the model, the train (and optionally valid) dataloaders, the
    learning-rate scheduler and the optimizer from ``cfg``, then runs the
    epoch loop with optional precise-BN, validation and checkpoint saving.

    Args:
        cfg (dict): configuration.
        weights (str): weights path for finetuning. Must not be combined
            with a non-zero ``cfg.resume_epoch``.
        parallel (bool): Whether multi-cards training. Default: True.
        validate (bool): Whether to do evaluation. Default: True.
        amp (bool): Whether to run the forward pass under
            ``paddle.amp.auto_cast`` and step through a ``GradScaler``.
            Default: False.
        use_fleet (bool): Whether to initialise fleet in collective mode and
            wrap the model and optimizer with fleet. Default: False.

    Raises:
        AssertionError: if the gradient-accumulation batch sizes are
            inconsistent, or if ``weights`` is given together with a
            non-zero ``resume_epoch``.
    """
    if use_fleet:
        fleet.init(is_collective=True)

    logger = get_logger("paddlevideo")
    batch_size = cfg.DATASET.get('batch_size', 8)
    valid_batch_size = cfg.DATASET.get('valid_batch_size', batch_size)

    use_gradient_accumulation = cfg.get('GRADIENT_ACCUMULATION', None)
    if use_gradient_accumulation and dist.get_world_size() >= 1:
        global_batch_size = cfg.GRADIENT_ACCUMULATION.get(
            'global_batch_size', None)
        num_gpus = dist.get_world_size()

        assert isinstance(
            global_batch_size, int
        ), f"global_batch_size must be int, but got {type(global_batch_size)}"
        assert batch_size < global_batch_size, "global_batch_size must bigger than batch_size"

        # Number of samples processed by all GPUs in a single iteration.
        cur_global_batch_size = batch_size * num_gpus
        assert global_batch_size % cur_global_batch_size == 0, (
            f"The global batchsize must be divisible by cur_global_batch_size, "
            f"but {global_batch_size} % {cur_global_batch_size} != 0")

        # The number of iterations required to reach the global batchsize
        cfg.GRADIENT_ACCUMULATION[
            "num_iters"] = global_batch_size // cur_global_batch_size
        logger.info(
            f"Using gradient accumulation training strategy, "
            f"global_batch_size={global_batch_size}, "
            f"num_gpus={num_gpus}, "
            f"num_accumulative_iters={cfg.GRADIENT_ACCUMULATION.num_iters}")

    places = paddle.set_device('gpu')

    # default num worker: 0, which means no subprocess will be created
    num_workers = cfg.DATASET.get('num_workers', 0)
    valid_num_workers = cfg.DATASET.get('valid_num_workers', num_workers)
    model_name = cfg.model_name
    output_dir = cfg.get("output_dir", f"./output/{model_name}")
    mkdir(output_dir)

    # 1. Construct model
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    if use_fleet:
        # The distributed wrapper is exposed by fleet, not by the paddle
        # top-level namespace.
        model = fleet.distributed_model(model)

    # 2. Construct dataset and dataloader
    train_dataset = build_dataset((cfg.DATASET.train, cfg.PIPELINE.train))
    train_dataloader_setting = dict(batch_size=batch_size,
                                    num_workers=num_workers,
                                    collate_fn_cfg=cfg.get('MIX', None),
                                    places=places)

    train_loader = build_dataloader(train_dataset, **train_dataloader_setting)
    if validate:
        valid_dataset = build_dataset((cfg.DATASET.valid, cfg.PIPELINE.valid))
        validate_dataloader_setting = dict(
            batch_size=valid_batch_size,
            num_workers=valid_num_workers,
            places=places,
            drop_last=False,
            shuffle=cfg.DATASET.get(
                'shuffle_valid',
                False)  #NOTE: attention lstm need shuffle valid data.
        )
        valid_loader = build_dataloader(valid_dataset,
                                        **validate_dataloader_setting)

    # 3. Construct solver.
    lr = build_lr(cfg.OPTIMIZER.learning_rate, len(train_loader))
    optimizer = build_optimizer(cfg.OPTIMIZER,
                                lr,
                                parameter_list=model.parameters())
    if use_fleet:
        optimizer = fleet.distributed_optimizer(optimizer)
    # Resume
    resume_epoch = cfg.get("resume_epoch", 0)
    if resume_epoch:
        filename = osp.join(output_dir,
                            model_name + f"_epoch_{resume_epoch:05d}")
        resume_model_dict = load(filename + '.pdparams')
        resume_opt_dict = load(filename + '.pdopt')
        model.set_state_dict(resume_model_dict)
        optimizer.set_state_dict(resume_opt_dict)

    # Finetune:
    if weights:
        assert resume_epoch == 0, "Conflict occurs when finetuning, please switch resume function off by setting resume_epoch to 0 or not indicating it."
        model_dict = load(weights)
        model.set_state_dict(model_dict)

    # 4. Train Model
    ###AMP###
    if amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=2.0**16,
                                       incr_every_n_steps=2000,
                                       decr_every_n_nan_or_inf=1)

    def evaluate(best, epoch):
        """Run one pass over ``valid_loader`` and update the best metric.

        Only called when ``validate`` is true, so ``valid_loader`` exists.

        Args:
            best (float): best metric value seen so far.
            epoch (int): zero-based index of the current epoch (for logs).

        Returns:
            tuple: ``(best, best_flag)`` where ``best_flag`` is True when
            this pass improved on the incoming ``best``.
        """
        model.eval()
        record_list = build_record(cfg.MODEL)
        record_list.pop('lr')
        tic = time.time()
        for i, data in enumerate(valid_loader):
            outputs = model(data, mode='valid')

            # log_record: the valid loader is built with valid_batch_size,
            # so that is the weight and throughput unit to report.
            for name, value in outputs.items():
                record_list[name].update(value, valid_batch_size)

            record_list['batch_time'].update(time.time() - tic)
            tic = time.time()

            if i % cfg.get("log_interval", 10) == 0:
                ips = "ips: {:.5f} instance/sec.".format(
                    valid_batch_size / record_list["batch_time"].val)
                log_batch(record_list, i, epoch + 1, cfg.epochs, "val", ips)

        ips = "avg_ips: {:.5f} instance/sec.".format(
            valid_batch_size * record_list["batch_time"].count /
            record_list["batch_time"].sum)
        log_epoch(record_list, epoch + 1, "val", ips)

        best_flag = False
        for top_flag in ['hit_at_one', 'top1']:
            if record_list.get(
                    top_flag) and record_list[top_flag].avg > best:
                best = record_list[top_flag].avg
                best_flag = True

        return best, best_flag

    best = 0.
    for epoch in range(0, cfg.epochs):
        if epoch < resume_epoch:
            logger.info(
                f"| epoch: [{epoch+1}] <= resume_epoch: [{ resume_epoch}], continue... "
            )
            continue
        model.train()

        record_list = build_record(cfg.MODEL)
        tic = time.time()
        for i, data in enumerate(train_loader):
            record_list['reader_time'].update(time.time() - tic)

            # 4.1 forward

            ###AMP###
            if amp:
                with paddle.amp.auto_cast(custom_black_list={"reduce_mean"}):
                    outputs = model(data, mode='train')

                avg_loss = outputs['loss']
                scaled = scaler.scale(avg_loss)
                scaled.backward()
                # keep prior to 2.0 design
                scaler.minimize(optimizer, scaled)
                optimizer.clear_grad()

            else:
                outputs = model(data, mode='train')

                # 4.2 backward
                if use_gradient_accumulation and i == 0:  # Use gradient accumulation strategy
                    optimizer.clear_grad()
                avg_loss = outputs['loss']
                avg_loss.backward()

                # 4.3 minimize
                if use_gradient_accumulation:  # Use gradient accumulation strategy
                    if (i + 1) % cfg.GRADIENT_ACCUMULATION.num_iters == 0:
                        # Average the accumulated gradients before stepping.
                        for p in model.parameters():
                            # Parameters that received no gradient (e.g.
                            # frozen ones) have nothing to rescale.
                            if p.grad is None:
                                continue
                            p.grad.set_value(
                                p.grad / cfg.GRADIENT_ACCUMULATION.num_iters)
                        optimizer.step()
                        optimizer.clear_grad()
                else:  # Common case
                    optimizer.step()
                    optimizer.clear_grad()

            # log record
            record_list['lr'].update(optimizer.get_lr(), batch_size)
            for name, value in outputs.items():
                record_list[name].update(value, batch_size)

            record_list['batch_time'].update(time.time() - tic)
            tic = time.time()

            if i % cfg.get("log_interval", 10) == 0:
                ips = "ips: {:.5f} instance/sec.".format(
                    batch_size / record_list["batch_time"].val)
                log_batch(record_list, i, epoch + 1, cfg.epochs, "train", ips)

            # learning rate iter step
            if cfg.OPTIMIZER.learning_rate.get("iter_step"):
                lr.step()

        # learning rate epoch step
        if not cfg.OPTIMIZER.learning_rate.get("iter_step"):
            lr.step()

        ips = "avg_ips: {:.5f} instance/sec.".format(
            batch_size * record_list["batch_time"].count /
            record_list["batch_time"].sum)
        log_epoch(record_list, epoch + 1, "train", ips)

        # use precise bn to improve acc
        if cfg.get("PRECISEBN") and (epoch % cfg.PRECISEBN.preciseBN_interval
                                     == 0 or epoch == cfg.epochs - 1):
            do_preciseBN(
                model, train_loader, parallel,
                min(cfg.PRECISEBN.num_iters_preciseBN, len(train_loader)))

        # 5. Validation
        if validate and (epoch % cfg.get("val_interval", 1) == 0
                         or epoch == cfg.epochs - 1):
            with paddle.no_grad():
                best, save_best_flag = evaluate(best, epoch)
            # save best
            if save_best_flag:
                save(optimizer.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdopt"))
                save(model.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdparams"))
                if model_name == "AttentionLstm":
                    logger.info(
                        f"Already save the best model (hit_at_one){best}")
                else:
                    logger.info(
                        f"Already save the best model (top1 acc){int(best *10000)/10000}"
                    )

        # 6. Save model and optimizer
        if epoch % cfg.get("save_interval", 1) == 0 or epoch == cfg.epochs - 1:
            save(
                optimizer.state_dict(),
                osp.join(output_dir,
                         model_name + f"_epoch_{epoch+1:05d}.pdopt"))
            save(
                model.state_dict(),
                osp.join(output_dir,
                         model_name + f"_epoch_{epoch+1:05d}.pdparams"))

    logger.info(f'training {model_name} finished')