Пример #1
0
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None):
    """
    Launch training.

    Args:
        model (nn.Layer): The model to train. It is put into training mode
            before the first iteration.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation
            datasets. When given, the model is evaluated at every checkpoint
            iteration and the best one (by mIoU) is saved under
            ``save_dir/best_model``. Default: None.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model. Default: None.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (optional): Loss configuration forwarded unchanged to
            ``loss_computation``. Default: None.

    Raises:
        ValueError: If iterations remain to be trained but the training data
            loader yields no batches (incomplete batches are dropped).
    """
    # Start in training mode even if the caller handed over a model that was
    # left in eval mode; evaluation below switches it back with model.train().
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize parallel training environment.
        paddle.distributed.init_parallel_env()
        strategy = paddle.distributed.prepare_context()
        ddp_model = paddle.DataParallel(model, strategy)
    # Module used for the forward pass: the DataParallel wrapper on multiple
    # ranks, the plain model otherwise.
    forward_model = ddp_model if nranks > 1 else model

    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
    )

    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0 and start_iter < iters:
        # An empty loader never advances the iteration counter, so the
        # training loop below would spin forever.
        raise ValueError(
            'The training data loader yields no batches; make sure every '
            'rank receives at least batch_size ({}) samples, since '
            'incomplete batches are dropped.'.format(batch_size))

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    # Evaluation uses at most one data-loading worker.
    eval_num_workers = 1 if num_workers > 0 else 0

    timer = Timer()
    avg_loss = 0.0
    best_mean_iou = -1.0
    best_model_iter = -1
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    timer.start()

    cur_iter = start_iter
    while cur_iter < iters:
        for data in loader:
            cur_iter += 1
            if cur_iter > iters:
                break
            train_reader_cost += timer.elapsed_time()
            images = data[0]
            labels = data[1].astype('int64')

            logits = forward_model(images)
            loss = loss_computation(logits, labels, losses)
            loss.backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()
            # float() accepts both shape-[1] and 0-D loss tensors, whereas
            # ``loss.numpy()[0]`` fails on 0-D tensors.
            avg_loss += float(loss)
            train_batch_cost += timer.elapsed_time()

            if cur_iter % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_train_reader_cost = train_reader_cost / log_iters
                avg_train_batch_cost = train_batch_cost / log_iters
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                remain_iters = iters - cur_iter
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
                    .format((cur_iter - 1) // iters_per_epoch + 1, cur_iter,
                            iters, avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost, eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, cur_iter)
                    log_writer.add_scalar('Train/lr', lr, cur_iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, cur_iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, cur_iter)
                avg_loss = 0.0

            is_checkpoint_iter = (cur_iter % save_interval == 0
                                  or cur_iter == iters)

            # Evaluation is not restricted to rank 0 here.
            if is_checkpoint_iter and val_dataset is not None:
                mean_iou, acc = evaluate(
                    model, val_dataset, num_workers=eval_num_workers)
                model.train()

            if is_checkpoint_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(cur_iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))

                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = cur_iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou,
                                              cur_iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, cur_iter)
            timer.restart()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
Пример #2
0
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          eval_begin_iters=None):
    """
    Launch training.

    Args:
        model(nn.Layer): A matting model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        eval_begin_iters (int, optional): The iteration from which evaluation starts.
            If None, it defaults to iters // 2. Default: None.

    Raises:
        ValueError: If iterations remain to be trained but the training data
            loader yields no batches (incomplete batches are dropped).
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize the parallel environment only if it has not been done
        # yet; the model is wrapped with DataParallel in either case.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
        ddp_model = paddle.DataParallel(model)
    # Module used for the forward pass: the DataParallel wrapper on multiple
    # ranks, the plain model otherwise.
    forward_model = ddp_model if nranks > 1 else model

    batch_sampler = paddle.io.DistributedBatchSampler(train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
    )

    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0 and start_iter < iters:
        # An empty loader never advances the iteration counter, so the
        # training loop below would spin forever.
        raise ValueError(
            'The training data loader yields no batches; make sure every '
            'rank receives at least batch_size ({}) samples, since '
            'incomplete batches are dropped.'.format(batch_size))

    if eval_begin_iters is None:
        eval_begin_iters = iters // 2

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    avg_loss = defaultdict(float)
    best_sad = np.inf
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    # Directories of the saved checkpoints, oldest first, used to enforce
    # keep_checkpoint_max.
    save_models = deque()
    batch_start = time.time()

    cur_iter = start_iter
    while cur_iter < iters:
        for data in loader:
            cur_iter += 1
            if cur_iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)

            # The whole batch is fed to the model and to its loss method.
            logit_dict = forward_model(data)
            loss_dict = model.loss(logit_dict, data, losses)

            loss_dict['all'].backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()

            # float() accepts both shape-[1] and 0-D loss tensors, whereas
            # ``value.numpy()[0]`` fails on 0-D tensors.
            for key, value in loss_dict.items():
                avg_loss[key] += float(value)
            batch_cost_averager.record(time.time() - batch_start,
                                       num_samples=batch_size)

            if cur_iter % log_iters == 0 and local_rank == 0:
                for key in avg_loss:
                    avg_loss[key] /= log_iters
                remain_iters = iters - cur_iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}"
                    .format((cur_iter - 1) // iters_per_epoch + 1, cur_iter,
                            iters, avg_loss['all'], lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta))
                # Print every loss term, with the total ('all') first.
                loss_terms = ['all={:.4f}'.format(avg_loss['all'])]
                loss_terms.extend(
                    key + '={:.4f}'.format(value)
                    for key, value in avg_loss.items() if key != 'all')
                logger.info('[TRAIN] [LOSS] ' + ' '.join(loss_terms))
                if use_vdl:
                    for key, value in avg_loss.items():
                        log_writer.add_scalar('Train/' + key, value, cur_iter)

                    log_writer.add_scalar('Train/lr', lr, cur_iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, cur_iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, cur_iter)

                for key in avg_loss:
                    avg_loss[key] = 0.
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            is_checkpoint_iter = (cur_iter % save_interval == 0
                                  or cur_iter == iters)

            # save model
            if is_checkpoint_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(cur_iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

            # Evaluate on rank 0 only, from eval_begin_iters on; then keep
            # the best model (lowest SAD) and add the results to vdl.
            if (is_checkpoint_iter and local_rank == 0
                    and val_dataset is not None
                    and cur_iter >= eval_begin_iters):
                sad, mse = evaluate(model,
                                    val_dataset,
                                    num_workers=0,
                                    print_detail=True,
                                    save_results=False)
                model.train()

                if sad < best_sad:
                    best_sad = sad
                    best_model_iter = cur_iter
                    best_model_dir = os.path.join(save_dir, "best_model")
                    paddle.save(
                        model.state_dict(),
                        os.path.join(best_model_dir, 'model.pdparams'))
                logger.info(
                    '[EVAL] The model with the best validation sad ({:.4f}) was saved at iter {}.'
                    .format(best_sad, best_model_iter))

                if use_vdl:
                    log_writer.add_scalar('Evaluate/SAD', sad, cur_iter)
                    log_writer.add_scalar('Evaluate/MSE', mse, cur_iter)

            batch_start = time.time()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
Пример #3
0
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None):
    """
    Launch training.

    Args:
        model(nn.Layer): A semantic segmentation model. It is put into training
            mode before the first iteration.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
            The 'types' item is a list of object of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficient.

    Raises:
        ValueError: If iterations remain to be trained but the training data
            loader yields no batches (incomplete batches are dropped).
    """
    # Start in training mode even if the caller handed over a model that was
    # left in eval mode; evaluation below switches it back with model.train().
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize parallel training environment.
        paddle.distributed.init_parallel_env()
        strategy = paddle.distributed.prepare_context()
        ddp_model = paddle.DataParallel(model, strategy)
    # Module used for the forward pass: the DataParallel wrapper on multiple
    # ranks, the plain model otherwise.
    forward_model = ddp_model if nranks > 1 else model

    batch_sampler = paddle.io.DistributedBatchSampler(train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
    )

    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0 and start_iter < iters:
        # An empty loader never advances the iteration counter, so the
        # training loop below would spin forever.
        raise ValueError(
            'The training data loader yields no batches; make sure every '
            'rank receives at least batch_size ({}) samples, since '
            'incomplete batches are dropped.'.format(batch_size))

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    # Evaluation uses at most one data-loading worker.
    eval_num_workers = 1 if num_workers > 0 else 0

    timer = Timer()
    avg_loss = 0.0
    best_mean_iou = -1.0
    best_model_iter = -1
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    timer.start()

    cur_iter = start_iter
    while cur_iter < iters:
        for data in loader:
            cur_iter += 1
            if cur_iter > iters:
                break
            train_reader_cost += timer.elapsed_time()
            images = data[0]
            labels = data[1].astype('int64')
            # An optional third element of the batch carries edge labels.
            edges = None
            if len(data) == 3:
                edges = data[2].astype('int64')

            logits_list = forward_model(images)
            loss = loss_computation(logits_list=logits_list,
                                    labels=labels,
                                    losses=losses,
                                    edges=edges)
            loss.backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()
            # float() accepts both shape-[1] and 0-D loss tensors, whereas
            # ``loss.numpy()[0]`` fails on 0-D tensors.
            avg_loss += float(loss)
            train_batch_cost += timer.elapsed_time()

            if cur_iter % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_train_reader_cost = train_reader_cost / log_iters
                avg_train_batch_cost = train_batch_cost / log_iters
                train_reader_cost = 0.0
                train_batch_cost = 0.0
                remain_iters = iters - cur_iter
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
                    .format((cur_iter - 1) // iters_per_epoch + 1, cur_iter,
                            iters, avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost, eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, cur_iter)
                    log_writer.add_scalar('Train/lr', lr, cur_iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, cur_iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, cur_iter)
                avg_loss = 0.0

            is_checkpoint_iter = (cur_iter % save_interval == 0
                                  or cur_iter == iters)

            # Evaluation is not restricted to rank 0 here.
            if is_checkpoint_iter and val_dataset is not None:
                mean_iou, acc = evaluate(model,
                                         val_dataset,
                                         num_workers=eval_num_workers)
                model.train()

            if is_checkpoint_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(cur_iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))

                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = cur_iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou,
                                              cur_iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, cur_iter)
            timer.restart()

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
Пример #4
0
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          test_config=None,
          fp16=False,
          profiler_options=None):
    """
    Launch training.

    Args:
        model(nn.Layer): A semantic segmentation model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict, optional): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
            The 'types' item is a list of object of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficient.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        test_config(dict, optional): Extra keyword arguments passed to ``evaluate``. Default: None.
        fp16 (bool, optional): Whether to use amp.
        profiler_options (str, optional): The option of train profiler.

    Raises:
        ValueError: If iterations remain to be trained but the training data
            loader yields no batches (incomplete batches are dropped).
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        paddle.distributed.fleet.init(is_collective=True)
        optimizer = paddle.distributed.fleet.distributed_optimizer(
            optimizer)  # The return is Fleet object
        ddp_model = paddle.distributed.fleet.distributed_model(model)
    # Module used for the forward pass: the fleet-wrapped model on multiple
    # ranks, the plain model otherwise.
    forward_model = ddp_model if nranks > 1 else model

    def _forward_with_loss(images, labels, edges):
        # Run the forward pass; return (per-term loss list, summed loss).
        logits_list = forward_model(images)
        loss_list = loss_computation(logits_list=logits_list,
                                     labels=labels,
                                     losses=losses,
                                     edges=edges)
        return loss_list, sum(loss_list)

    batch_sampler = paddle.io.DistributedBatchSampler(train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      drop_last=True)

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
        worker_init_fn=worker_init_fn,
    )

    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0 and start_iter < iters:
        # An empty loader never advances the iteration counter, so the
        # training loop below would spin forever.
        raise ValueError(
            'The training data loader yields no batches; make sure every '
            'rank receives at least batch_size ({}) samples, since '
            'incomplete batches are dropped.'.format(batch_size))

    # use amp
    if fp16:
        logger.info('use amp to train')
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    if test_config is None:
        test_config = {}
    # Evaluation uses at most one data-loading worker.
    eval_num_workers = 1 if num_workers > 0 else 0
    # On Paddle 2.1.2 the loader is drained instead of broken out of once
    # training is done (presumably a workaround for a DataLoader issue in
    # that release -- TODO confirm).
    drain_loader_when_done = paddle.__version__ == '2.1.2'

    avg_loss = 0.0
    avg_loss_list = []
    best_acc = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    # Directories of the saved checkpoints, oldest first, used to enforce
    # keep_checkpoint_max.
    save_models = deque()
    batch_start = time.time()
    # Last training batch; stays None if no iteration runs (e.g. when
    # resuming from the final iteration).
    images = None

    cur_iter = start_iter
    while cur_iter < iters:
        for data in loader:
            cur_iter += 1
            if cur_iter > iters:
                if drain_loader_when_done:
                    continue
                break
            reader_cost_averager.record(time.time() - batch_start)
            images = data[0]
            labels = data[1].astype('int64')
            # An optional third element of the batch carries edge labels.
            edges = None
            if len(data) == 3:
                edges = data[2].astype('int64')
            if hasattr(model, 'data_format') and model.data_format == 'NHWC':
                images = images.transpose((0, 2, 3, 1))

            if fp16:
                with paddle.amp.auto_cast(
                        enable=True,
                        custom_white_list={
                            "elementwise_add", "batch_norm", "sync_batch_norm"
                        },
                        custom_black_list={'bilinear_interp_v2'}):
                    loss_list, loss = _forward_with_loss(images, labels, edges)

                scaled = scaler.scale(loss)  # scale the loss
                scaled.backward()  # do backward
                if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                    scaler.minimize(optimizer.user_defined_optimizer, scaled)
                else:
                    scaler.minimize(optimizer, scaled)  # update parameters
            else:
                loss_list, loss = _forward_with_loss(images, labels, edges)
                loss.backward()
                optimizer.step()

            lr = optimizer.get_lr()

            # update lr
            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                lr_sche = optimizer.user_defined_optimizer._learning_rate
            else:
                lr_sche = optimizer._learning_rate
            if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                lr_sche.step()

            train_profiler.add_profiler_step(profiler_options)

            model.clear_gradients()
            # float() accepts both shape-[1] and 0-D loss tensors, whereas
            # indexing ``.numpy()[0]`` fails on 0-D tensors.
            avg_loss += float(loss)
            if not avg_loss_list:
                avg_loss_list = [float(term) for term in loss_list]
            else:
                for i, term in enumerate(loss_list):
                    avg_loss_list[i] += float(term)
            batch_cost_averager.record(time.time() - batch_start,
                                       num_samples=batch_size)

            if cur_iter % log_iters == 0 and local_rank == 0:
                avg_loss /= log_iters
                avg_loss_list = [value / log_iters for value in avg_loss_list]
                remain_iters = iters - cur_iter
                avg_train_batch_cost = batch_cost_averager.get_average()
                avg_train_reader_cost = reader_cost_averager.get_average()
                eta = calculate_eta(remain_iters, avg_train_batch_cost)
                logger.info(
                    "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                    .format((cur_iter - 1) // iters_per_epoch + 1, cur_iter,
                            iters, avg_loss, lr, avg_train_batch_cost,
                            avg_train_reader_cost,
                            batch_cost_averager.get_ips_average(), eta))
                if use_vdl:
                    log_writer.add_scalar('Train/loss', avg_loss, cur_iter)
                    # Record every loss term when there is more than one.
                    if len(avg_loss_list) > 1:
                        for i, value in enumerate(avg_loss_list):
                            log_writer.add_scalar('Train/loss_' + str(i),
                                                  value, cur_iter)

                    log_writer.add_scalar('Train/lr', lr, cur_iter)
                    log_writer.add_scalar('Train/batch_cost',
                                          avg_train_batch_cost, cur_iter)
                    log_writer.add_scalar('Train/reader_cost',
                                          avg_train_reader_cost, cur_iter)
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            is_checkpoint_iter = (cur_iter % save_interval == 0
                                  or cur_iter == iters)

            # Evaluation is not restricted to rank 0 here.
            if is_checkpoint_iter and val_dataset is not None:
                acc, fp, fn = evaluate(model,
                                       val_dataset,
                                       num_workers=eval_num_workers,
                                       save_dir=save_dir,
                                       **test_config)

                model.train()

            if is_checkpoint_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(cur_iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

                if val_dataset is not None:
                    if acc > best_acc:
                        best_acc = acc
                        best_model_iter = cur_iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation Acc ({:.4f}) was saved at iter {}.'
                        .format(best_acc, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/Acc', acc, cur_iter)
                        log_writer.add_scalar('Evaluate/Fp', fp, cur_iter)
                        log_writer.add_scalar('Evaluate/Fn', fn, cur_iter)
            batch_start = time.time()

    # Calculate flops from the shape of the last training batch. Skipped when
    # no iteration ran, which previously raised UnboundLocalError here.
    if local_rank == 0 and images is not None:
        # For NHWC models these are (H, W, C), matching the transposed input.
        _, c, h, w = images.shape
        _ = paddle.flops(
            model, [1, c, h, w],
            custom_ops={paddle.nn.SyncBatchNorm: op_flops_funs.count_syncbn})

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
Пример #5
0
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5,
          threshold=0.1,
          nms_kernel=7,
          top_k=200):
    """
    Launch training of a panoptic segmentation model.

    Args:
        model(nn.Layer): A panoptic segmentation model. Its forward output is passed to
            `loss_computation` together with semantic, center and offset targets.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets. Each batch
            must provide images, semantic, semantic_weights, center, center_weights, offset and
            offset_weights at indices 0-6.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot once during training.
            Evaluation runs at the same iters, but only in the second half of training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
            The 'types' item is a list of object of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficient.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
        threshold (float, optional): A Float, threshold applied to center heatmap score. Default: 0.1.
        nms_kernel (int, optional): An Integer, NMS max pooling kernel size. Default: 7.
        top_k (int, optional): An Integer, top k centers to keep. Default: 200.

    Raises:
        ValueError: If the training data yields no batch on this rank (e.g. the dataset is smaller
            than `batch_size`, since `drop_last=True`). Without this check the training loop
            would never advance.
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    # A plain file occupying the output path is replaced by a directory.
    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
        ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      drop_last=True)
    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0 and start_iter < iters:
        # An empty loader would leave `cur_iter` unchanged forever in the
        # `while` loop below; fail loudly instead of hanging.
        raise ValueError(
            'The training dataset yields no batches (dataset size: {}, '
            'batch_size: {}, drop_last=True), so training cannot progress.'
            .format(len(train_dataset), batch_size))

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
    )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    avg_loss = 0.0
    avg_loss_list = []
    best_pq = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    # Last training batch, reused for the FLOPs report. It stays None when no
    # iteration runs (e.g. resuming at or beyond `iters`).
    images = None
    batch_start = time.time()

    cur_iter = start_iter
    while cur_iter < iters:
        for data in loader:
            cur_iter += 1
            if cur_iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)
            images = data[0]
            semantic = data[1]
            semantic_weights = data[2]
            center = data[3]
            center_weights = data[4]
            offset = data[5]
            offset_weights = data[6]

            if nranks > 1:
                logits_list = ddp_model(images)
            else:
                logits_list = model(images)

            loss_list = loss_computation(logits_list=logits_list,
                                         losses=losses,
                                         semantic=semantic,
                                         semantic_weights=semantic_weights,
                                         center=center,
                                         center_weights=center_weights,
                                         offset=offset,
                                         offset_weights=offset_weights)
            loss = sum(loss_list)
            loss.backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()
            avg_loss += loss.numpy()[0]
            if not avg_loss_list:
                avg_loss_list = [item.numpy() for item in loss_list]
            else:
                for i, item in enumerate(loss_list):
                    avg_loss_list[i] += item.numpy()
            batch_cost_averager.record(time.time() - batch_start,
                                       num_samples=batch_size)

            if cur_iter % log_iters == 0:
                if local_rank == 0:
                    avg_loss /= log_iters
                    avg_loss_list = [
                        value[0] / log_iters for value in avg_loss_list
                    ]
                    remain_iters = iters - cur_iter
                    avg_train_batch_cost = batch_cost_averager.get_average()
                    avg_train_reader_cost = reader_cost_averager.get_average()
                    eta = calculate_eta(remain_iters, avg_train_batch_cost)
                    logger.info(
                        "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f}, ips={:.4f} samples/sec | ETA {}"
                        .format((cur_iter - 1) // iters_per_epoch + 1,
                                cur_iter, iters, avg_loss, lr,
                                avg_train_batch_cost, avg_train_reader_cost,
                                batch_cost_averager.get_ips_average(), eta))
                    logger.info(
                        "[LOSS] loss={:.4f}, semantic_loss={:.4f}, center_loss={:.4f}, offset_loss={:.4f}"
                        .format(avg_loss, avg_loss_list[0], avg_loss_list[1],
                                avg_loss_list[2]))
                    if use_vdl:
                        log_writer.add_scalar('Train/loss', avg_loss, cur_iter)
                        # Record each loss term separately when there are several.
                        if len(avg_loss_list) > 1:
                            for i, value in enumerate(avg_loss_list):
                                log_writer.add_scalar('Train/loss_' + str(i),
                                                      value, cur_iter)

                        log_writer.add_scalar('Train/lr', lr, cur_iter)
                        log_writer.add_scalar('Train/batch_cost',
                                              avg_train_batch_cost, cur_iter)
                        log_writer.add_scalar('Train/reader_cost',
                                              avg_train_reader_cost, cur_iter)

                # Reset on every rank so non-logging ranks do not accumulate
                # forever.
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            is_save_iter = cur_iter % save_interval == 0 or cur_iter == iters

            # save model
            if is_save_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(cur_iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                # Drop the oldest snapshot once the limit (if positive) is exceeded.
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

            # Evaluate (second half of training only), then keep the best model
            # and add evaluation results to vdl.
            if (is_save_iter and val_dataset is not None and local_rank == 0
                    and cur_iter > iters // 2):
                eval_num_workers = 1 if num_workers > 0 else 0
                panoptic_results, semantic_results, instance_results = evaluate(
                    model,
                    val_dataset,
                    threshold=threshold,
                    nms_kernel=nms_kernel,
                    top_k=top_k,
                    num_workers=eval_num_workers,
                    print_detail=False)
                pq = panoptic_results['pan_seg']['All']['pq']
                miou = semantic_results['sem_seg']['mIoU']
                mean_ap = instance_results['ins_seg']['mAP']
                map50 = instance_results['ins_seg']['mAP50']
                logger.info(
                    "[EVAL] PQ: {:.4f}, mIoU: {:.4f}, mAP: {:.4f}, mAP50: {:.4f}"
                    .format(pq, miou, mean_ap, map50))
                model.train()

                if pq > best_pq:
                    best_pq = pq
                    best_model_iter = cur_iter
                    best_model_dir = os.path.join(save_dir, "best_model")
                    paddle.save(model.state_dict(),
                                os.path.join(best_model_dir, 'model.pdparams'))
                logger.info(
                    '[EVAL] The model with the best validation pq ({:.4f}) was saved at iter {}.'
                    .format(best_pq, best_model_iter))

                if use_vdl:
                    log_writer.add_scalar('Evaluate/PQ', pq, cur_iter)
                    log_writer.add_scalar('Evaluate/mIoU', miou, cur_iter)
                    log_writer.add_scalar('Evaluate/mAP', mean_ap, cur_iter)
                    log_writer.add_scalar('Evaluate/mAP50', map50, cur_iter)
            batch_start = time.time()

    # Calculate flops (needs a sample batch shape, so skip if none was seen).
    if local_rank == 0 and images is not None:

        def count_syncbn(m, x, y):
            x = x[0]
            nelements = x.numel()
            m.total_ops += int(2 * nelements)

        _, c, h, w = images.shape
        paddle.flops(model, [1, c, h, w],
                     custom_ops={paddle.nn.SyncBatchNorm: count_syncbn})

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
Пример #6
0
def train(model,
          train_dataset,
          val_dataset=None,
          optimizer=None,
          save_dir='output',
          iters=10000,
          batch_size=2,
          resume_model=None,
          save_interval=1000,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          losses=None,
          keep_checkpoint_max=5):
    """
    Launch training.

    Args:
        model(nn.Layer): A semantic segmentation model.
        train_dataset (paddle.io.Dataset): Used to read and process training datasets. Each batch
            provides images and labels, plus optional edges as a third element.
        val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
        optimizer (paddle.optimizer.Optimizer): The optimizer.
        save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
        iters (int, optional): How many iters to train the model. Default: 10000.
        batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
        resume_model (str, optional): The path of resume model.
        save_interval (int, optional): How many iters to save a model snapshot (and evaluate) once during training. Default: 1000.
        log_iters (int, optional): Display logging information at every log_iters. Default: 10.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
        losses (dict): A dict including 'types' and 'coef'. The length of coef should equal to 1 or len(losses['types']).
            The 'types' item is a list of object of paddleseg.models.losses while the 'coef' item is a list of the relevant coefficient.
        keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.

    Raises:
        ValueError: If the training data yields no batch on this rank (e.g. the dataset is smaller
            than `batch_size`, since `drop_last=True`). Without this check the training loop
            would never advance.
    """
    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank

    start_iter = 0
    if resume_model is not None:
        start_iter = resume(model, optimizer, resume_model)

    # A plain file occupying the output path is replaced by a directory.
    if not os.path.isdir(save_dir):
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
        ddp_model = paddle.DataParallel(model)

    batch_sampler = paddle.io.DistributedBatchSampler(train_dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      drop_last=True)
    iters_per_epoch = len(batch_sampler)
    if iters_per_epoch == 0 and start_iter < iters:
        # An empty loader would leave `cur_iter` unchanged forever in the
        # `while` loop below; fail loudly instead of hanging.
        raise ValueError(
            'The training dataset yields no batches (dataset size: {}, '
            'batch_size: {}, drop_last=True), so training cannot progress.'
            .format(len(train_dataset), batch_size))

    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True,
    )

    if use_vdl:
        from visualdl import LogWriter
        log_writer = LogWriter(save_dir)

    avg_loss = 0.0
    avg_loss_list = []
    best_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    # Last training batch, reused for the FLOPs report. It stays None when no
    # iteration runs (e.g. resuming at or beyond `iters`).
    images = None
    batch_start = time.time()

    cur_iter = start_iter
    while cur_iter < iters:
        for data in loader:
            cur_iter += 1
            if cur_iter > iters:
                break
            reader_cost_averager.record(time.time() - batch_start)
            images = data[0]
            labels = data[1].astype('int64')
            # Edge labels are optional and only present as a third element.
            edges = data[2].astype('int64') if len(data) == 3 else None

            if nranks > 1:
                logits_list = ddp_model(images)
            else:
                logits_list = model(images)
            loss_list = loss_computation(logits_list=logits_list,
                                         labels=labels,
                                         losses=losses,
                                         edges=edges)
            loss = sum(loss_list)
            loss.backward()

            optimizer.step()
            lr = optimizer.get_lr()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            model.clear_gradients()
            avg_loss += loss.numpy()[0]
            if not avg_loss_list:
                avg_loss_list = [item.numpy() for item in loss_list]
            else:
                for i, item in enumerate(loss_list):
                    avg_loss_list[i] += item.numpy()
            batch_cost_averager.record(time.time() - batch_start,
                                       num_samples=batch_size)

            if cur_iter % log_iters == 0:
                if local_rank == 0:
                    avg_loss /= log_iters
                    avg_loss_list = [
                        value[0] / log_iters for value in avg_loss_list
                    ]
                    remain_iters = iters - cur_iter
                    avg_train_batch_cost = batch_cost_averager.get_average()
                    avg_train_reader_cost = reader_cost_averager.get_average()
                    eta = calculate_eta(remain_iters, avg_train_batch_cost)
                    logger.info(
                        "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                        .format((cur_iter - 1) // iters_per_epoch + 1,
                                cur_iter, iters, avg_loss, lr,
                                avg_train_batch_cost, avg_train_reader_cost,
                                batch_cost_averager.get_ips_average(), eta))
                    if use_vdl:
                        log_writer.add_scalar('Train/loss', avg_loss, cur_iter)
                        # Record each loss term separately when there are several.
                        if len(avg_loss_list) > 1:
                            for i, value in enumerate(avg_loss_list):
                                log_writer.add_scalar('Train/loss_' + str(i),
                                                      value, cur_iter)

                        log_writer.add_scalar('Train/lr', lr, cur_iter)
                        log_writer.add_scalar('Train/batch_cost',
                                              avg_train_batch_cost, cur_iter)
                        log_writer.add_scalar('Train/reader_cost',
                                              avg_train_reader_cost, cur_iter)

                # Reset on every rank so non-logging ranks do not accumulate
                # forever.
                avg_loss = 0.0
                avg_loss_list = []
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            is_save_iter = cur_iter % save_interval == 0 or cur_iter == iters

            # Evaluation runs on every rank (not only local_rank 0).
            if is_save_iter and val_dataset is not None:
                eval_num_workers = 1 if num_workers > 0 else 0
                mean_iou, acc, class_iou, _, _ = evaluate(
                    model, val_dataset, num_workers=eval_num_workers)
                model.train()

            if is_save_iter and local_rank == 0:
                current_save_dir = os.path.join(save_dir,
                                                "iter_{}".format(cur_iter))
                if not os.path.isdir(current_save_dir):
                    os.makedirs(current_save_dir)
                paddle.save(model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                save_models.append(current_save_dir)
                # Drop the oldest snapshot once the limit (if positive) is exceeded.
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
                    shutil.rmtree(model_to_remove)

                if val_dataset is not None:
                    if mean_iou > best_mean_iou:
                        best_mean_iou = mean_iou
                        best_model_iter = cur_iter
                        best_model_dir = os.path.join(save_dir, "best_model")
                        paddle.save(
                            model.state_dict(),
                            os.path.join(best_model_dir, 'model.pdparams'))
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou,
                                              cur_iter)
                        for i, iou in enumerate(class_iou):
                            log_writer.add_scalar('Evaluate/IoU {}'.format(i),
                                                  float(iou), cur_iter)

                        log_writer.add_scalar('Evaluate/Acc', acc, cur_iter)
            batch_start = time.time()

    # Calculate flops (needs a sample batch shape, so skip if none was seen).
    if local_rank == 0 and images is not None:

        def count_syncbn(m, x, y):
            x = x[0]
            nelements = x.numel()
            m.total_ops += int(2 * nelements)

        _, c, h, w = images.shape
        paddle.flops(model, [1, c, h, w],
                     custom_ops={paddle.nn.SyncBatchNorm: count_syncbn})

    # Sleep for half a second to let dataloader release resources.
    time.sleep(0.5)
    if use_vdl:
        log_writer.close()
Пример #7
0
    def train(self,
              train_dataset_src,
              train_dataset_tgt,
              val_dataset_tgt=None,
              val_dataset_src=None,
              optimizer=None,
              save_dir='output',
              iters=10000,
              batch_size=2,
              resume_model=None,
              save_interval=1000,
              log_iters=10,
              num_workers=0,
              use_vdl=False,
              keep_checkpoint_max=5,
              test_config=None):
        """
        Launch training.

        Args:
            train_dataset (paddle.io.Dataset): Used to read and process training datasets.
            val_dataset_tgt (paddle.io.Dataset, optional): Used to read and process validation datasets.
            optimizer (paddle.optimizer.Optimizer): The optimizer.
            save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
            iters (int, optional): How may iters to train the model. Defualt: 10000.
            batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
            resume_model (str, optional): The path of resume model.
            save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
            log_iters (int, optional): Display logging information at every log_iters. Default: 10.
            num_workers (int, optional): Num workers for data loader. Default: 0.
            use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
            keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
            test_config(dict, optional): Evaluation config.
        """
        start_iter = 0
        self.model.train()
        nranks = paddle.distributed.ParallelEnv().nranks
        local_rank = paddle.distributed.ParallelEnv().local_rank

        if resume_model is not None:
            logger.info(resume_model)
            start_iter = resume(self.model, optimizer, resume_model)
        load_ema_model(self.model, self.resume_ema)

        if not os.path.isdir(save_dir):
            if os.path.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)

        if nranks > 1:
            paddle.distributed.fleet.init(is_collective=True)
            optimizer = paddle.distributed.fleet.distributed_optimizer(
                optimizer)  # The return is Fleet object
            ddp_model = paddle.distributed.fleet.distributed_model(self.model)

        batch_sampler_src = paddle.io.DistributedBatchSampler(
            train_dataset_src,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True)

        loader_src = paddle.io.DataLoader(
            train_dataset_src,
            batch_sampler=batch_sampler_src,
            num_workers=num_workers,
            return_list=True,
            worker_init_fn=worker_init_fn,
        )
        batch_sampler_tgt = paddle.io.DistributedBatchSampler(
            train_dataset_tgt,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True)

        loader_tgt = paddle.io.DataLoader(
            train_dataset_tgt,
            batch_sampler=batch_sampler_tgt,
            num_workers=num_workers,
            return_list=True,
            worker_init_fn=worker_init_fn,
        )

        if use_vdl:
            from visualdl import LogWriter
            log_writer = LogWriter(save_dir)

        iters_per_epoch = len(batch_sampler_tgt)
        best_mean_iou = -1.0
        best_model_iter = -1
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()
        save_models = deque()
        batch_start = time.time()

        iter = start_iter
        while iter < iters:
            for _, (data_src,
                    data_tgt) in enumerate(zip(loader_src, loader_tgt)):

                reader_cost_averager.record(time.time() - batch_start)
                loss_dict = {}

                #### training #####
                images_tgt = data_tgt[0]
                labels_tgt = data_tgt[1].astype('int64')
                images_src = data_src[0]
                labels_src = data_src[1].astype('int64')

                edges_src = data_src[2].astype('int64')
                edges_tgt = data_tgt[2].astype('int64')

                if nranks > 1:
                    logits_list_src = ddp_model(images_src)
                else:
                    logits_list_src = self.model(images_src)

                ##### source seg & edge loss ####
                loss_src_seg_main = self.celoss(logits_list_src[0], labels_src)
                loss_src_seg_aux = 0.1 * self.celoss(logits_list_src[1],
                                                     labels_src)

                loss_src_seg = loss_src_seg_main + loss_src_seg_aux
                loss_dict["source_main"] = loss_src_seg_main.numpy()[0]
                loss_dict["source_aux"] = loss_src_seg_aux.numpy()[0]
                loss = loss_src_seg
                del loss_src_seg, loss_src_seg_aux, loss_src_seg_main

                #### generate target pseudo label  ####
                with paddle.no_grad():
                    if nranks > 1:
                        logits_list_tgt = ddp_model(images_tgt)
                    else:
                        logits_list_tgt = self.model(images_tgt)

                    pred_P_1 = F.softmax(logits_list_tgt[0], axis=1)
                    labels_tgt_psu = paddle.argmax(pred_P_1.detach(), axis=1)

                    # aux label
                    pred_P_2 = F.softmax(logits_list_tgt[1], axis=1)
                    pred_c = (pred_P_1 + pred_P_2) / 2
                    labels_tgt_psu_aux = paddle.argmax(pred_c.detach(), axis=1)

                if self.edgeconstrain:
                    loss_src_edge = self.bceloss_src(
                        logits_list_src[2], edges_src)  # 1, 2 640, 1280
                    src_edge = paddle.argmax(
                        logits_list_src[2].detach().clone(),
                        axis=1)  # 1, 1, 640,1280
                    src_edge_acc = ((src_edge == edges_src).numpy().sum().astype('float32')\
                                                /functools.reduce(lambda a, b: a * b, src_edge.shape))*100

                    if (not self.src_only) and (iter > 200000):
                        ####  target seg & edge loss ####
                        logger.info("Add target edege loss")
                        edges_tgt = Func.mask_to_binary_edge(
                            labels_tgt_psu.detach().clone().numpy(),
                            radius=2,
                            num_classes=train_dataset_tgt.NUM_CLASSES)
                        edges_tgt = paddle.to_tensor(edges_tgt, dtype='int64')

                        loss_tgt_edge = self.bceloss_tgt(
                            logits_list_tgt[2], edges_tgt)
                        loss_edge = loss_tgt_edge + loss_src_edge
                    else:
                        loss_tgt_edge = paddle.zeros([1])
                        loss_edge = loss_src_edge

                    loss += loss_edge

                    loss_dict['target_edge'] = loss_tgt_edge.numpy()[0]
                    loss_dict['source_edge'] = loss_src_edge.numpy()[0]

                    del loss_edge, loss_tgt_edge, loss_src_edge

                #### target aug loss #######
                augs = augmentation.get_augmentation()
                images_tgt_aug, labels_tgt_aug = augmentation.augment(
                    images=images_tgt.cpu(),
                    labels=labels_tgt_psu.detach().cpu(),
                    aug=augs,
                    iters="{}_1".format(iter))
                images_tgt_aug = images_tgt_aug.cuda()
                labels_tgt_aug = labels_tgt_aug.cuda()

                _, labels_tgt_aug_aux = augmentation.augment(
                    images=images_tgt.cpu(),
                    labels=labels_tgt_psu_aux.detach().cpu(),
                    aug=augs,
                    iters="{}_2".format(iter))
                labels_tgt_aug_aux = labels_tgt_aug_aux.cuda()

                if nranks > 1:
                    logits_list_tgt_aug = ddp_model(images_tgt_aug)
                else:
                    logits_list_tgt_aug = self.model(images_tgt_aug)

                loss_tgt_aug_main = 0.1 * (self.celoss(logits_list_tgt_aug[0],
                                                       labels_tgt_aug))
                loss_tgt_aug_aux = 0.1 * (0.1 * self.celoss(
                    logits_list_tgt_aug[1], labels_tgt_aug_aux))

                loss_tgt_aug = loss_tgt_aug_aux + loss_tgt_aug_main

                loss += loss_tgt_aug

                loss_dict['target_aug_main'] = loss_tgt_aug_main.numpy()[0]
                loss_dict['target_aug_aux'] = loss_tgt_aug_aux.numpy()[0]
                del images_tgt_aug, labels_tgt_aug_aux, images_tgt, \
                    loss_tgt_aug, loss_tgt_aug_aux, loss_tgt_aug_main

                #### edge input seg; src & tgt edge pull in ######
                if self.edgepullin:
                    src_edge_logit = logits_list_src[2]
                    feat_src = paddle.concat(
                        [logits_list_src[0], src_edge_logit], axis=1).detach()

                    out_src = self.model.fusion(feat_src)
                    loss_src_edge_rec = self.celoss(out_src, labels_src)

                    tgt_edge_logit = logits_list_tgt_aug[2]
                    # tgt_edge_logit = paddle.to_tensor(
                    #     Func.mask_to_onehot(edges_tgt.squeeze().numpy(), 2)
                    #     ).unsqueeze(0).astype('float32')
                    feat_tgt = paddle.concat(
                        [logits_list_tgt[0], tgt_edge_logit], axis=1).detach()

                    out_tgt = self.model.fusion(feat_tgt)
                    loss_tgt_edge_rec = self.celoss(out_tgt, labels_tgt)

                    loss_edge_rec = loss_tgt_edge_rec + loss_src_edge_rec
                    loss += loss_edge_rec

                    loss_dict['src_edge_rec'] = loss_src_edge_rec.numpy()[0]
                    loss_dict['tgt_edge_rec'] = loss_tgt_edge_rec.numpy()[0]

                    del loss_tgt_edge_rec, loss_src_edge_rec

                #### mask input feature & pullin  ######
                if self.featurepullin:
                    # inner-class loss
                    feat_src = logits_list_src[0]
                    feat_tgt = logits_list_tgt_aug[0]
                    center_src_s, center_tgt_s = [], []

                    total_pixs = logits_list_src[0].shape[2] * \
                                    logits_list_src[0].shape[3]

                    for i in range(train_dataset_tgt.NUM_CLASSES):
                        pred = paddle.argmax(
                            logits_list_src[0].detach().clone(),
                            axis=1).unsqueeze(0)  # 1, 1, 640, 1280
                        sel_num = paddle.sum((pred == i).astype('float32'))
                        # ignore tensor that do not have features in this img
                        if sel_num > 0:
                            feat_sel_src = paddle.where(
                                (pred == i).expand_as(feat_src), feat_src,
                                paddle.zeros(feat_src.shape))
                            center_src = paddle.mean(feat_sel_src, axis=[
                                2, 3
                            ]) / (sel_num / total_pixs)  # 1, C

                            self.src_centers[i] = 0.99 * self.src_centers[
                                i] + (1 - 0.99) * center_src

                        pred = labels_tgt_aug.unsqueeze(0)  # 1, 1,  512, 512
                        sel_num = paddle.sum((pred == i).astype('float32'))
                        if sel_num > 0:
                            feat_sel_tgt = paddle.where(
                                (pred == i).expand_as(feat_tgt), feat_tgt,
                                paddle.zeros(feat_tgt.shape))
                            center_tgt = paddle.mean(feat_sel_tgt, axis=[
                                2, 3
                            ]) / (sel_num / total_pixs)

                            self.tgt_centers[i] = 0.99 * self.tgt_centers[
                                i] + (1 - 0.99) * center_tgt
                        center_src_s.append(center_src)
                        center_tgt_s.append(center_tgt)

                    if iter >= 3000:  # average center structure alignment
                        src_centers = paddle.concat(self.src_centers, axis=0)
                        tgt_centers = paddle.concat(self.tgt_centers,
                                                    axis=0)  # 19, 2048

                        relatmat_src = paddle.matmul(src_centers,
                                                     src_centers,
                                                     transpose_y=True)  # 19,19
                        relatmat_tgt = paddle.matmul(tgt_centers,
                                                     tgt_centers,
                                                     transpose_y=True)

                        loss_intra_relate = self.klloss(relatmat_src, (relatmat_tgt+relatmat_src)/2) \
                                            + self.klloss(relatmat_tgt, (relatmat_tgt+relatmat_src)/2)

                        loss_pix_align_src = self.mseloss(
                            paddle.to_tensor(center_src_s),
                            paddle.to_tensor(
                                self.src_centers).detach().clone())
                        loss_pix_align_tgt = self.mseloss(
                            paddle.to_tensor(center_tgt_s),
                            paddle.to_tensor(
                                self.tgt_centers).detach().clone())

                        loss_feat_align = loss_pix_align_src + loss_pix_align_tgt + loss_intra_relate

                        loss += loss_feat_align

                        loss_dict['loss_pix_align_src'] = \
                                loss_pix_align_src.numpy()[0]
                        loss_dict['loss_pix_align_tgt'] = \
                                loss_pix_align_tgt.numpy()[0]
                        loss_dict['loss_intra_relate'] = \
                                loss_intra_relate.numpy()[0]

                        del loss_pix_align_tgt, loss_pix_align_src, loss_intra_relate,

                    self.tgt_centers = [
                        item.detach().clone() for item in self.tgt_centers
                    ]
                    self.src_centers = [
                        item.detach().clone() for item in self.src_centers
                    ]

                loss.backward()
                del loss
                loss = sum(loss_dict.values())

                optimizer.step()
                self.ema.update_params()

                with paddle.no_grad():
                    ##### log & save #####
                    lr = optimizer.get_lr()
                    # update lr
                    if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                        lr_sche = optimizer.user_defined_optimizer._learning_rate
                    else:
                        lr_sche = optimizer._learning_rate
                    if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                        lr_sche.step()

                    if self.cfg['save_edge']:
                        tgt_edge = paddle.argmax(
                            logits_list_tgt_aug[2].detach().clone(),
                            axis=1)  # 1, 1, 640,1280
                        src_feed_gt = paddle.argmax(
                            src_edge_logit.astype('float32'), axis=1)
                        tgt_feed_gt = paddle.argmax(
                            tgt_edge_logit.astype('float32'), axis=1)
                        logger.info('src_feed_gt_{}_{}_{}'.format(
                            src_feed_gt.shape, src_feed_gt.max(),
                            src_feed_gt.min()))
                        logger.info('tgt_feed_gt_{}_{}_{}'.format(
                            tgt_feed_gt.shape, max(tgt_feed_gt),
                            min(tgt_feed_gt)))
                        save_edge(src_feed_gt, 'src_feed_gt_{}'.format(iter))
                        save_edge(tgt_feed_gt, 'tgt_feed_gt_{}'.format(iter))
                        save_edge(tgt_edge, 'tgt_pred_{}'.format(iter))
                        save_edge(src_edge,
                                  'src_pred_{}_{}'.format(iter, src_edge_acc))
                        save_edge(edges_src, 'src_gt_{}'.format(iter))
                        save_edge(edges_tgt, 'tgt_gt_{}'.format(iter))

                    self.model.clear_gradients()

                    batch_cost_averager.record(time.time() - batch_start,
                                               num_samples=batch_size)

                    iter += 1
                    if (iter) % log_iters == 0 and local_rank == 0:
                        label_tgt_acc = ((labels_tgt == labels_tgt_psu).numpy().sum().astype('float32')\
                                            /functools.reduce(lambda a, b: a * b, labels_tgt_psu.shape))*100

                        remain_iters = iters - iter
                        avg_train_batch_cost = batch_cost_averager.get_average(
                        )
                        avg_train_reader_cost = reader_cost_averager.get_average(
                        )
                        eta = calculate_eta(remain_iters, avg_train_batch_cost)
                        logger.info(
                            "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, tgt_pix_acc: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                            .format(
                                (iter - 1) // iters_per_epoch + 1, iter, iters,
                                loss, label_tgt_acc, lr, avg_train_batch_cost,
                                avg_train_reader_cost,
                                batch_cost_averager.get_ips_average(), eta))

                        if use_vdl:
                            log_writer.add_scalar('Train/loss', loss, iter)
                            # Record all losses if there are more than 2 losses.
                            if len(loss_dict) > 1:
                                for name, loss in loss_dict.items():
                                    log_writer.add_scalar(
                                        'Train/loss_' + name, loss, iter)

                            log_writer.add_scalar('Train/lr', lr, iter)
                            log_writer.add_scalar('Train/batch_cost',
                                                  avg_train_batch_cost, iter)
                            log_writer.add_scalar('Train/reader_cost',
                                                  avg_train_reader_cost, iter)
                            log_writer.add_scalar('Train/tgt_label_acc',
                                                  label_tgt_acc, iter)

                        reader_cost_averager.reset()
                        batch_cost_averager.reset()

                    if (iter % save_interval == 0 or iter
                            == iters) and (val_dataset_tgt is not None):
                        num_workers = 4 if num_workers > 0 else 0  # adjust num_worker=4

                        if test_config is None:
                            test_config = {}
                        self.ema.apply_shadow()
                        self.ema.model.eval()

                        PA_tgt, _, MIoU_tgt, _ = val.evaluate(
                            self.model,
                            val_dataset_tgt,
                            num_workers=num_workers,
                            **test_config)

                        if (iter % (save_interval * 30)) == 0 \
                            and self.cfg['eval_src']:  # add evaluate on src
                            PA_src, _, MIoU_src, _ = val.evaluate(
                                self.model,
                                val_dataset_src,
                                num_workers=num_workers,
                                **test_config)
                            logger.info(
                                '[EVAL] The source mIoU is ({:.4f}) at iter {}.'
                                .format(MIoU_src, iter))

                        self.ema.restore()
                        self.model.train()

                    if (iter % save_interval == 0
                            or iter == iters) and local_rank == 0:
                        current_save_dir = os.path.join(
                            save_dir, "iter_{}".format(iter))
                        if not os.path.isdir(current_save_dir):
                            os.makedirs(current_save_dir)
                        paddle.save(
                            self.model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                        paddle.save(
                            self.ema.shadow,
                            os.path.join(current_save_dir,
                                         'model_ema.pdparams'))
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                        save_models.append(current_save_dir)
                        if len(save_models) > keep_checkpoint_max > 0:
                            model_to_remove = save_models.popleft()
                            shutil.rmtree(model_to_remove)

                        if val_dataset_tgt is not None:
                            if MIoU_tgt > best_mean_iou:
                                best_mean_iou = MIoU_tgt
                                best_model_iter = iter
                                best_model_dir = os.path.join(
                                    save_dir, "best_model")
                                paddle.save(
                                    self.model.state_dict(),
                                    os.path.join(best_model_dir,
                                                 'model.pdparams'))

                            logger.info(
                                '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                                .format(best_mean_iou, best_model_iter))

                        if use_vdl:
                            log_writer.add_scalar('Evaluate/mIoU', MIoU_tgt,
                                                  iter)
                            log_writer.add_scalar('Evaluate/PA', PA_tgt, iter)

                            if self.cfg['eval_src']:
                                log_writer.add_scalar('Evaluate/mIoU_src',
                                                      MIoU_src, iter)
                                log_writer.add_scalar('Evaluate/PA_src',
                                                      PA_src, iter)

                    batch_start = time.time()

            self.ema.update_buffer()

        # # Calculate flops.
        if local_rank == 0:

            def count_syncbn(m, x, y):
                """Custom FLOPs counter used for ``paddle.nn.SyncBatchNorm``.

                Adds two operations per element of the layer's first input
                to the layer's ``total_ops`` accumulator.

                Args:
                    m: The layer being profiled; its ``total_ops`` is incremented.
                    x: Sequence of layer inputs; only ``x[0]`` is inspected.
                    y: Layer output (unused).
                """
                first_input = x[0]
                # Two ops per element, truncated to a Python int before accumulating.
                m.total_ops += int(2 * first_input.numel())

            _, c, h, w = images_src.shape
            flops = paddle.flops(
                self.model, [1, c, h, w],
                custom_ops={paddle.nn.SyncBatchNorm: count_syncbn})

        # Sleep for half a second to let dataloader release resources.
        time.sleep(0.5)
        if use_vdl:
            log_writer.close()