# Example #1
# 0
def training_step(x, opt, first_batch):
    """Take one optimizer step on the loss ``x * x`` and return that loss.

    Args:
        x: the variable being optimized.
        opt: optimizer whose ``apply_gradients`` updates ``x``.
        first_batch: when true, broadcast ``x`` and the optimizer variables
            with KungFu after the update.
    """
    with tf.GradientTape() as tape:
        loss = x * x
    watched = [x]
    opt.apply_gradients(zip(tape.gradient(loss, watched), watched))

    # KungFu: broadcasting only after the first applied step guarantees the
    # optimizer's own variables have been created and can be synchronized.
    if not first_batch:
        return loss
    broadcast_variables([x])
    broadcast_variables(opt.variables())
    return loss
# Example #2
# 0
def benchmark_step(first_batch):
    """Run one categorical-crossentropy training step on ``data``/``target``.

    ``model``, ``opt``, ``data`` and ``target`` come from the enclosing scope.
    When ``first_batch`` is true, all model variables and the optimizer
    variables are broadcast with KungFu after the update. Returns nothing.
    """
    with tf.GradientTape() as tape:
        predictions = model(data, training=True)
        batch_loss = tf.losses.categorical_crossentropy(target, predictions)

    params = model.trainable_variables
    opt.apply_gradients(zip(tape.gradient(batch_loss, params), params))

    if not first_batch:
        return
    # KungFu: sync after the first step so optimizer variables already exist.
    from kungfu.tensorflow.initializer import broadcast_variables
    broadcast_variables(model.variables)
    broadcast_variables(opt.variables())
def training_step(images, labels, first_batch):
    """Train ``mnist_model`` on one batch and return the batch loss.

    Args:
        images: input batch fed to ``mnist_model``.
        labels: targets passed to ``loss`` together with the predictions.
        first_batch: when true, broadcast the model variables and optimizer
            variables with KungFu after the update.
    """
    with tf.GradientTape() as tape:
        predictions = mnist_model(images, training=True)
        batch_loss = loss(labels, predictions)

    params = mnist_model.trainable_variables
    opt.apply_gradients(zip(tape.gradient(batch_loss, params), params))

    # KungFu: broadcast only after the first gradient step, once the
    # optimizer has initialized its variables.
    if first_batch:
        from kungfu.tensorflow.initializer import broadcast_variables
        for group in (mnist_model.variables,):
            broadcast_variables(group)
        broadcast_variables(opt.variables())

    return batch_loss
# Example #4
# 0
 def one_step(image,gt_label,mask,train_model,is_first_batch=False):
     """Run one training step of a PIF/PAF model.

     Increments ``step``, computes the model loss plus the weight-decay term
     from ``regulize_loss``, applies the gradients with ``opt`` and, on the
     first batch, broadcasts model weights and optimizer state with KungFu.
     ``mask`` is accepted but not used.

     Returns (pd_pif_maps, pd_paf_maps, loss_pif_maps, loss_paf_maps,
     decay_loss, total_loss).
     """
     step.assign_add(1)
     with tf.GradientTape() as tape:
         target_pif, target_paf = gt_label
         pred_pif, pred_paf = train_model.forward(image, is_train=True)
         pif_loss, paf_loss, loss_sum = train_model.cal_loss(pred_pif, pred_paf, target_pif, target_paf)
         reg_loss = regulize_loss(train_model, weight_decay_factor)
         loss_sum += reg_loss

     params = train_model.trainable_weights
     grads = tape.gradient(loss_sum, params)
     opt.apply_gradients(zip(grads, params))

     # KungFu: broadcast after the first update so optimizer state exists.
     if is_first_batch:
         broadcast_variables(train_model.all_weights)
         broadcast_variables(opt.variables())
     return pred_pif, pred_paf, pif_loss, paf_loss, reg_loss, loss_sum
# Example #5
# 0
    def one_step(image,gt_label,mask,train_model,is_first_batch=False):
        """Run one training step of a confidence/PAF model.

        ``gt_label`` is split along axis 1 at ``n_pos``: the first ``n_pos``
        channels are the confidence targets, the rest the PAF targets. The
        total loss is the model loss plus the ``regulize_loss`` term. On the
        first batch, model weights and optimizer state are broadcast with
        KungFu after the update.

        Returns (gt_conf, gt_paf, pd_conf, pd_paf, total_loss, re_loss).
        """
        step.assign_add(1)
        with tf.GradientTape() as tape:
            conf_target = gt_label[:, :n_pos, :, :]
            paf_target = gt_label[:, n_pos:, :, :]
            conf_pred, paf_pred, conf_stages, paf_stages = train_model.forward(image, is_train=True)

            pred_loss, _conf_losses, _paf_losses = train_model.cal_loss(
                conf_target, paf_target, mask, conf_stages, paf_stages)
            reg_loss = regulize_loss(train_model, weight_decay_factor)
            loss_sum = pred_loss + reg_loss

        params = train_model.trainable_weights
        opt.apply_gradients(zip(tape.gradient(loss_sum, params), params))
        # KungFu: broadcast after the first update so optimizer state exists.
        if is_first_batch:
            broadcast_variables(train_model.all_weights)
            broadcast_variables(opt.variables())
        return conf_target, paf_target, conf_pred, paf_pred, loss_sum, reg_loss
# Example #6
# 0
    def one_step(image,targets,train_model,is_first_batch=False):
        """Run one training step with a five-term prediction loss.

        ``targets`` is unpacked into seven tensors that are passed, together
        with the seven model outputs, to ``train_model.cal_loss``. The
        prediction loss is the sum of its five returned terms, and the total
        loss adds the ``regulize_loss`` term. On the first batch, model
        weights and optimizer state are broadcast with KungFu after the
        update.

        Returns (predicts, targets, pd_loss, re_loss, loss_rsp, loss_iou,
        loss_coor, loss_size, loss_limb), where ``predicts`` holds six of the
        seven model outputs (the second output is omitted).
        """
        step.assign_add(1)
        with tf.GradientTape() as tape:
            t_delta, t_x, t_y, t_w, t_h, t_e, t_e_mask = targets
            out_c, out_i, out_x, out_y, out_w, out_h, out_e = train_model.forward(image, is_train=True)
            loss_terms = train_model.cal_loss(t_delta, t_x, t_y, t_w, t_h, t_e, t_e_mask,
                                              out_c, out_i, out_x, out_y, out_w, out_h, out_e)
            rsp_loss, iou_loss, coor_loss, size_loss, limb_loss = loss_terms
            pred_loss = rsp_loss + iou_loss + coor_loss + size_loss + limb_loss
            reg_loss = regulize_loss(train_model, weight_decay_factor)
            loss_sum = pred_loss + reg_loss

        params = train_model.trainable_weights
        opt.apply_gradients(zip(tape.gradient(loss_sum, params), params))
        # KungFu: broadcast after the first update so optimizer state exists.
        if is_first_batch:
            broadcast_variables(train_model.all_weights)
            broadcast_variables(opt.variables())
        outputs = (out_c, out_x, out_y, out_w, out_h, out_e)
        return outputs, targets, pred_loss, reg_loss, rsp_loss, iou_loss, coor_loss, size_loss, limb_loss
# Example #7
# 0
def sync_offsets(xs):
    """Synchronize the variables in ``xs`` across workers via ``broadcast_variables``.

    NOTE(review): a broadcast presumably makes every worker take one root
    worker's values; the TODO below indicates an all-reduce with a max op
    is the intended behavior instead — confirm before relying on either.
    """
    # TODO: use all_reduce with max op
    broadcast_variables(xs)
# Example #8
# 0
def sync_model(model, opt):
    """Broadcast ``model.variables`` and then ``opt.variables()`` with KungFu.

    Call this after the optimizer has applied at least one update so that
    its variables exist.
    """
    model_state = model.variables
    broadcast_variables(model_state)
    optimizer_state = opt.variables()
    broadcast_variables(optimizer_state)
# Example #9
# 0
def train(parallel, kungfu_option):
    """Train CycleGAN generators Gab/Gba and discriminators Da/Db.

    A single Adam optimizer updates all four networks from the summed loss
    ``loss_D + loss_G``. From epoch 100 on, the learning rate decays linearly
    from ``flags.lr_init`` towards 0 (clamped at 0). After every epoch the
    chief writes sample translations, and every 5 epochs it saves the weights.

    Parameters
    ----------
    parallel : bool
        Train with KungFu: wrap the optimizer in a KungFu distributed
        optimizer, shard the data by rank, broadcast model and optimizer
        state once after the first gradient step, and let only rank 0
        (the chief) write samples and weights.
    kungfu_option : str
        ``'sync-sgd'``, ``'async-sgd'`` or ``'sma'``; only read when
        ``parallel`` is true.

    Raises
    ------
    RuntimeError
        If ``parallel`` is true and ``kungfu_option`` is unknown.
    """
    Gab = models.get_G(name='Gab')
    Gba = models.get_G(name='Gba')
    Da = models.get_D(name='Da')
    Db = models.get_D(name='Db')

    Gab.train()
    Gba.train()
    Da.train()
    Db.train()

    lr_v = tf.Variable(flags.lr_init)
    # A single optimizer updates all four networks (needs enough GPU memory).
    optimizer = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)

    use_ident = False

    # KungFu: wrap the optimizer that the training loop actually uses.
    if parallel:
        from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer, SynchronousAveragingOptimizer, PairAveragingOptimizer
        if kungfu_option == 'sync-sgd':
            opt_fn = SynchronousSGDOptimizer
        elif kungfu_option == 'async-sgd':
            opt_fn = PairAveragingOptimizer
        elif kungfu_option == 'sma':
            opt_fn = SynchronousAveragingOptimizer
        else:
            raise RuntimeError('Unknown distributed training optimizer.')
        optimizer = opt_fn(optimizer)

    # Gab.load_weights(flags.model_dir + '/Gab.h5') # restore params?
    # Gba.load_weights(flags.model_dir + '/Gba.h5')
    # Da.load_weights(flags.model_dir + '/Da.h5')
    # Db.load_weights(flags.model_dir + '/Db.h5')

    # KungFu: shard the data; each worker keeps every cluster_size-th pair.
    if parallel:
        from kungfu import current_cluster_size, current_rank
        cluster_size = current_cluster_size()
        rank = current_rank()
        data_A_shard = []
        data_B_shard = []
        for idx, (image_A, image_B) in enumerate(zip(data_A, data_B)):
            if idx % cluster_size == rank:
                data_A_shard.append(image_A)
                data_B_shard.append(image_B)
        is_chief = rank == 0
    else:
        data_A_shard = data_A
        data_B_shard = data_B
        is_chief = True

    @tf.function
    def train_step(image_A, image_B):
        fake_B = Gab(image_A)
        fake_A = Gba(image_B)
        cycle_A = Gba(fake_B)
        cycle_B = Gab(fake_A)
        if use_ident:
            iden_A = Gba(image_A)
            iden_B = Gab(image_B)
        logits_fake_B = Db(fake_B)  # TODO: missing image buffer (pool)
        logits_real_B = Db(image_B)
        logits_fake_A = Da(fake_A)
        logits_real_A = Da(image_A)
        # loss_Da = (tl.cost.mean_squared_error(logits_real_A, tf.ones_like(logits_real_A), is_mean=True) + \  # LSGAN
        #     tl.cost.mean_squared_error(logits_fake_A, tf.ones_like(logits_fake_A), is_mean=True)) / 2.
        loss_Da = tf.reduce_mean(tf.math.squared_difference(logits_fake_A, tf.zeros_like(logits_fake_A))) + \
            tf.reduce_mean(tf.math.squared_difference(logits_real_A, tf.ones_like(logits_real_A)))
        # loss_Da = tl.cost.sigmoid_cross_entropy(logits_fake_A, tf.zeros_like(logits_fake_A)) + \
        # tl.cost.sigmoid_cross_entropy(logits_real_A, tf.ones_like(logits_real_A))
        # loss_Db = (tl.cost.mean_squared_error(logits_real_B, tf.ones_like(logits_real_B), is_mean=True) + \ # LSGAN
        #     tl.cost.mean_squared_error(logits_fake_B, tf.ones_like(logits_fake_B), is_mean=True)) / 2.
        loss_Db = tf.reduce_mean(tf.math.squared_difference(logits_fake_B, tf.zeros_like(logits_fake_B))) + \
            tf.reduce_mean(tf.math.squared_difference(logits_real_B, tf.ones_like(logits_real_B)))
        # loss_Db = tl.cost.sigmoid_cross_entropy(logits_fake_B, tf.zeros_like(logits_fake_B)) + \
        #     tl.cost.sigmoid_cross_entropy(logits_real_B, tf.ones_like(logits_real_B))
        # loss_Gab = tl.cost.mean_squared_error(logits_fake_B, tf.ones_like(logits_fake_B), is_mean=True) # LSGAN
        loss_Gab = tf.reduce_mean(
            tf.math.squared_difference(logits_fake_B,
                                       tf.ones_like(logits_fake_B)))
        # loss_Gab = tl.cost.sigmoid_cross_entropy(logits_fake_B, tf.ones_like(logits_fake_B))
        # loss_Gba = tl.cost.mean_squared_error(logits_fake_A, tf.ones_like(logits_fake_A), is_mean=True) # LSGAN
        loss_Gba = tf.reduce_mean(
            tf.math.squared_difference(logits_fake_A,
                                       tf.ones_like(logits_fake_A)))
        # loss_Gba = tl.cost.sigmoid_cross_entropy(logits_fake_A, tf.ones_like(logits_fake_A))
        # loss_cyc = 10 * (tl.cost.absolute_difference_error(image_A, cycle_A, is_mean=True) + \
        #     tl.cost.absolute_difference_error(image_B, cycle_B, is_mean=True))
        loss_cyc = 10. * (tf.reduce_mean(tf.abs(image_A - cycle_A)) +
                          tf.reduce_mean(tf.abs(image_B - cycle_B)))

        if use_ident:
            loss_iden = 5. * (tf.reduce_mean(tf.abs(image_A - iden_A)) +
                              tf.reduce_mean(tf.abs(image_B - iden_B)))
        else:
            loss_iden = 0.

        loss_G = loss_Gab + loss_Gba + loss_cyc + loss_iden
        loss_D = loss_Da + loss_Db
        return loss_G, loss_D, loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db, loss_D + loss_G

    # KungFu: broadcast exactly once, right after the first applied gradient
    # step (the optimizer's variables exist by then). A flag is used instead
    # of ``step == 0`` so it neither repeats every epoch nor gets skipped if
    # the first batch of an epoch is too small.
    needs_broadcast = parallel

    for epoch in range(0, flags.n_epoch):
        # reduce lr linearly after 100 epochs, from lr_init to 0 (never below 0)
        if epoch >= 100:
            new_lr = max(0., flags.lr_init - flags.lr_init * (epoch - 100) / 100)
            # tf.Variable.assign takes the new value as its first argument.
            lr_v.assign(new_lr)
            print("New learning rate %f" % new_lr)

        # train 1 epoch
        for step, (image_A, image_B) in enumerate(zip(data_A_shard, data_B_shard)):
            # stop when the remaining data in this epoch < batch_size
            if image_A.shape[0] != flags.batch_size or image_B.shape[0] != flags.batch_size:
                break
            step_time = time.time()
            # A single gradient call is made, so the tape need not be persistent.
            with tf.GradientTape() as tape:
                loss_G, loss_D, loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db, loss_DG = train_step(
                    image_A, image_B)

            # Same variable list for the gradient and the update.
            train_vars = (Gba.trainable_weights + Gab.trainable_weights +
                          Da.trainable_weights + Db.trainable_weights)
            grad = tape.gradient(loss_DG, train_vars)
            optimizer.apply_gradients(zip(grad, train_vars))

            print("Epoch[{}/{}] step[{}/{}] time:{:.3f} Gab:{:.3f} Gba:{:.3f} cyc:{:.3f} iden:{:.3f} Da:{:.3f} Db:{:.3f}".format(\
                epoch, flags.n_epoch, step, n_step_per_epoch, time.time()-step_time, \
                loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db))

            if needs_broadcast:
                from kungfu.tensorflow.initializer import broadcast_variables

                # Broadcast model variables
                for net in (Gab, Gba, Da, Db):
                    broadcast_variables(net.trainable_weights)
                # Broadcast the (single) optimizer's variables
                broadcast_variables(optimizer.variables())
                needs_broadcast = False

        # Let the chief worker do visualization and checkpoints.
        if is_chief:
            # visualization, one image at a time in case GPU memory is low
            outb_list = []
            for i in range(len(sample_A)):
                outb = Gab(sample_A[i][np.newaxis, :, :, :])
                outb_list.append(outb.numpy()[0])

            outa_list = []
            for i in range(len(sample_B)):
                outa = Gba(sample_B[i][np.newaxis, :, :, :])
                outa_list.append(outa.numpy()[0])
            tl.vis.save_images(np.asarray(outb_list), [1, 5],
                               flags.sample_dir + '/{}_a2b.png'.format(epoch))
            tl.vis.save_images(np.asarray(outa_list), [1, 5],
                               flags.sample_dir + '/{}_b2a.png'.format(epoch))

            # save models every 5 epochs
            if epoch % 5 == 0:
                Gab.save_weights(flags.model_dir + '/Gab.h5')
                Gba.save_weights(flags.model_dir + '/Gba.h5')
                Da.save_weights(flags.model_dir + '/Da.h5')
                Db.save_weights(flags.model_dir + '/Db.h5')
# Example #10
# 0
def parallel_train(train_model, dataset, config, augmentor: BasicAugmentor,
                   preprocessor: BasicPreProcessor, postprocessor: BasicPostProcessor,
                   visualizer=BasicVisualizer):
    '''KungFu data-parallel train pipeline for Openpose-class models.

    Given a model and a dataset, training starts automatically. The pipeline:
    1. stores and restores checkpoints (step, learning rate, optimizer state)
       in ``config.model.model_dir``
    2. logs loss and timing information through ``log``
    3. on rank 0, periodically visualizes model output with ``visualizer``
    4. on rank 0, periodically saves the newest model to
       ``<model_dir>/newest_model.npz`` (and the domain-adaptation
       discriminator to ``<model_dir>/newest_discriminator.npz``)

    Parameters
    ----------
    train_model : tensorlayer.models.MODEL
        a preset or user defined model object, obtained by Model.get_model() function
    dataset : dataset
        a constructed dataset object, obtained by Dataset.get_dataset() function
    config : config object
        reads ``config.train``, ``config.model``, ``config.pretrain``,
        ``config.data`` and ``config.log``.
    augmentor, preprocessor :
        used to build the dataset map functions.
    postprocessor :
        accepted for interface compatibility; not used here.
    visualizer :
        object whose ``visual_compare`` is called on rank 0.

    Returns
    -------
    None
    '''

    # train hyper params
    # dataset params
    total_step = config.train.n_step
    batch_size = config.train.batch_size
    # learning rate params
    lr_init = config.train.lr_init
    lr_decay_factor = config.train.lr_decay_factor
    lr_decay_steps = [
        200000, 300000, 360000, 420000, 480000, 540000, 600000, 700000, 800000,
        900000
    ]
    weight_decay_factor = config.train.weight_decay_factor
    # log and checkpoint params
    log_interval = config.log.log_interval
    vis_interval = config.train.vis_interval
    save_interval = config.train.save_interval

    # model hyper params
    data_format = train_model.data_format
    model_dir = config.model.model_dir
    pretrain_model_dir = config.pretrain.pretrain_model_dir
    pretrain_model_path = f"{pretrain_model_dir}/newest_{train_model.backbone.name}.npz"

    # metrics
    metric_manager = MetricManager()

    # initializing train dataset
    train_dataset = dataset.get_train_dataset()
    epoch_size = dataset.get_train_datasize() // batch_size
    paramed_map_fn = get_paramed_map_fn(augmentor=augmentor,
                                        preprocessor=preprocessor,
                                        data_format=data_format)
    train_dataset = train_dataset.shuffle(buffer_size=4096).repeat()
    train_dataset = train_dataset.map(
        paramed_map_fn, num_parallel_calls=get_num_parallel_calls())
    train_dataset = train_dataset.batch(config.train.batch_size)
    train_dataset = train_dataset.prefetch(3)
    train_dataset_iter = iter(train_dataset)

    # train configure
    save_step = tf.Variable(1, trainable=False)
    save_lr = tf.Variable(lr_init, trainable=False)
    # The optimizer reads its learning rate from save_lr, so assigning
    # save_lr changes the effective learning rate.
    opt = tf.keras.optimizers.Adam(learning_rate=save_lr)
    domainadapt_flag = config.data.domainadapt_flag
    total_epoch = total_step // epoch_size

    # domain adaptation params
    if not domainadapt_flag:
        ckpt = tf.train.Checkpoint(save_step=save_step,
                                   save_lr=save_lr,
                                   opt=opt)
    else:
        log("Domain adaptaion in training enabled!")
        # weight param
        lambda_adapt = 1e-4
        # construct discriminator model
        feature_hin = train_model.hin // train_model.backbone.scale_size
        feature_win = train_model.win // train_model.backbone.scale_size
        in_channels = train_model.backbone.out_channels
        adapt_dis = Discriminator(feature_hin,
                                  feature_win,
                                  in_channels,
                                  data_format=data_format)
        opt_d = tf.keras.optimizers.Adam(learning_rate=save_lr)
        ckpt = tf.train.Checkpoint(save_step=save_step,
                                   save_lr=save_lr,
                                   opt=opt,
                                   opt_d=opt_d)
        # construct domain adaptation dataset
        dmadapt_train_dataset = dataset.get_dmadapt_train_dataset()
        paramed_dmadapt_map_fn = get_paramed_dmadapt_map_fn(augmentor)
        dmadapt_train_dataset = dmadapt_train_dataset.map(
            paramed_dmadapt_map_fn,
            num_parallel_calls=get_num_parallel_calls())
        dmadapt_train_dataset = dmadapt_train_dataset.shuffle(
            buffer_size=4096).repeat()
        dmadapt_train_dataset = dmadapt_train_dataset.batch(
            config.train.batch_size)
        dmadapt_train_dataset = dmadapt_train_dataset.prefetch(3)
        dmadapt_train_dataset_iter = iter(dmadapt_train_dataset)

    # load from ckpt; CheckpointManager.latest_checkpoint is None when no
    # checkpoint exists (ckpt.restore(None) would silently do nothing).
    ckpt_manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=3)
    log("loading ckpt...")
    latest_ckpt = ckpt_manager.latest_checkpoint
    if latest_ckpt is None:
        log("ckpt_path doesn't exist, step and optimizer are initialized")
    else:
        try:
            ckpt.restore(latest_ckpt)
        except Exception as err:
            log(f"failed to restore ckpt {latest_ckpt} ({err}), step and optimizer are initialized")
    # load pretrained backbone
    try:
        log("loading pretrained backbone...")
        tl.files.load_and_assign_npz_dict(name=pretrain_model_path,
                                          network=train_model.backbone,
                                          skip=True)
    except Exception as err:
        log(f"pretrained backbone doesn't exist, model backbone are initialized ({err})")
    # load model weights
    try:
        log("loading saved training model weights...")
        train_model.load_weights(os.path.join(model_dir, "newest_model.npz"))
    except Exception as err:
        log(f"model_path doesn't exist, model parameters are initialized ({err})")
    if domainadapt_flag:
        try:
            log("loading saved domain adaptation discriminator weight...")
            adapt_dis.load_weights(
                os.path.join(model_dir, "newest_discriminator.npz"))
        except Exception as err:
            log(f"discriminator path doesn't exist, discriminator parameters are initialized ({err})")

    log(f"Parallel training using learning rate:{lr_init} batch_size:{batch_size}")
    step = save_step.numpy()
    lr = save_lr.numpy()

    # import kungfu
    from kungfu.python import current_cluster_size, current_rank
    from kungfu.tensorflow.initializer import broadcast_variables
    from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer, SynchronousAveragingOptimizer, PairAveragingOptimizer

    # KungFu: every worker processes its own batches, so the step/epoch
    # budgets and decay boundaries are divided by the cluster size.
    cluster_size = current_cluster_size()
    total_step = total_step // cluster_size + 1
    total_epoch = total_epoch // cluster_size + 1
    lr_decay_steps = [decay_step // cluster_size + 1 for decay_step in lr_decay_steps]

    # optimize one step
    def optimize_step(image, mask, target_x, train_model,
                      metric_manager: MetricManager):
        # tape
        with tf.GradientTape() as tape:
            predict_x = train_model.forward(x=image,
                                            is_train=True,
                                            ret_backbone=domainadapt_flag)
            total_loss = train_model.cal_loss(predict_x=predict_x, target_x=target_x,
                                              mask=mask, metric_manager=metric_manager)
        # optimize model
        gradients = tape.gradient(total_loss, train_model.trainable_weights)
        opt.apply_gradients(zip(gradients, train_model.trainable_weights))
        return predict_x

    def optimize_step_dmadapt(image_src, image_dst, train_model,
                              adapt_dis: Discriminator,
                              metric_manager: MetricManager):
        # tape
        with tf.GradientTape(persistent=True) as tape:
            # feature extraction
            # src feature
            predict_src = train_model.forward(x=image_src,
                                              is_train=True,
                                              ret_backbone=True)
            backbone_feature_src = predict_src["backbone_features"]
            adapt_pd_src = adapt_dis.forward(backbone_feature_src)
            # dst feature
            predict_dst = train_model.forward(x=image_dst,
                                              is_train=True,
                                              ret_backbone=True)
            backbone_feature_dst = predict_dst["backbone_features"]
            adapt_pd_dst = adapt_dis.forward(backbone_feature_dst)

            # loss calculation
            # loss of g
            g_adapt_loss = adapt_dis.cal_loss(x=adapt_pd_dst,
                                              label=True) * lambda_adapt
            # loss of d
            d_adapt_loss_src = adapt_dis.cal_loss(x=adapt_pd_src, label=True)
            d_adapt_loss_dst = adapt_dis.cal_loss(x=adapt_pd_dst, label=False)
            d_adapt_loss = (d_adapt_loss_src + d_adapt_loss_dst) / 2

        # optimize model
        g_gradient = tape.gradient(g_adapt_loss, train_model.trainable_weights)
        opt.apply_gradients(zip(g_gradient, train_model.trainable_weights))
        metric_manager.update("model/g_adapt_loss", g_adapt_loss)
        # optimize dis
        d_gradients = tape.gradient(d_adapt_loss, adapt_dis.trainable_weights)
        opt_d.apply_gradients(zip(d_gradients, adapt_dis.trainable_weights))
        metric_manager.update("dis/d_adapt_loss_src", d_adapt_loss_src)
        metric_manager.update("dis/d_adapt_loss_dst", d_adapt_loss_dst)
        # delete persistent tape
        del tape
        return predict_dst

    # formal training procedure

    # KungFu configure: wrap every optimizer that is used for training. The
    # closures above look up ``opt``/``opt_d`` when called, so they pick up
    # the wrapped optimizers.
    kungfu_option = config.train.kungfu_option
    kungfu_wrapper = None
    if kungfu_option == KUNGFU.Sync_sgd:
        print("using Kungfu.SynchronousSGDOptimizer!")
        kungfu_wrapper = SynchronousSGDOptimizer
    elif kungfu_option == KUNGFU.Sync_avg:
        print("using Kungfu.SynchronousAveragingOptimize!")
        kungfu_wrapper = SynchronousAveragingOptimizer
    elif kungfu_option == KUNGFU.Pair_avg:
        print("using Kungfu.PairAveragingOptimizer!")
        kungfu_wrapper = PairAveragingOptimizer
    if kungfu_wrapper is None:
        log(f"unrecognized kungfu_option {kungfu_option}, optimizers are not distributed!")
    else:
        opt = kungfu_wrapper(opt)
        if domainadapt_flag:
            # Without this, each worker would train its own discriminator.
            opt_d = kungfu_wrapper(opt_d)

    train_model.train()
    cur_epoch = step // epoch_size + 1
    log(f"Start Training- total_epoch: {total_epoch} total_step: {total_step} current_epoch:{cur_epoch} "\
        +f"current_step:{step} batch_size:{batch_size} lr_init:{lr_init} lr_decay_steps:{lr_decay_steps} "\
        +f"lr_decay_factor:{lr_decay_factor} weight_decay_factor:{weight_decay_factor}" )

    # KungFu: broadcast once, right after the first gradient step of this run,
    # when the optimizer variables exist. A flag is used because ``step``
    # starts at the restored save_step (>= 1) and is incremented before the
    # first update, so a ``step == 1`` test never fires.
    is_first_batch = True
    for epoch_idx in range(cur_epoch, total_epoch):
        log(f"Epoch {epoch_idx}/{total_epoch}:")
        for _ in tqdm(range(0, epoch_size)):
            step += 1
            metric_manager.start_timing()
            image, mask, target_list = next(train_dataset_iter)
            # extract gt_label: each element unpickles to a dict of per-sample
            # targets; stack each key across the batch (keys of the first sample).
            target_list = [
                cPickle.loads(target) for target in target_list.numpy()
            ]
            target_x = {
                key: np.stack([target[key] for target in target_list])
                for key in target_list[0]
            }
            target_x = to_tensor_dict(target_x)

            # learning rate decay; assign to save_lr so the optimizer uses it
            # immediately (not only after the next checkpoint save).
            if step in lr_decay_steps:
                new_lr_decay = lr_decay_factor**(lr_decay_steps.index(step) + 1)
                lr = lr_init * new_lr_decay
                save_lr.assign(lr)

            # optimize one step
            predict_x = optimize_step(image, mask, target_x, train_model,
                                      metric_manager)

            # optimize domain adaptation
            if domainadapt_flag:
                src_image = image
                dst_image = next(dmadapt_train_dataset_iter)
                predict_dst = optimize_step_dmadapt(src_image, dst_image,
                                                    train_model, adapt_dis,
                                                    metric_manager)

            if is_first_batch:
                broadcast_variables(train_model.all_weights)
                broadcast_variables(opt.variables())
                if domainadapt_flag:
                    broadcast_variables(adapt_dis.trainable_weights)
                    broadcast_variables(opt_d.variables())
                is_first_batch = False

            # log info periodly
            if (step != 0) and (step % log_interval) == 0:
                log(f"Train Epoch={epoch_idx} / {total_epoch}, Step={step} / {total_step}: learning_rate: {lr:.6e} {metric_manager.report_timing()}\n"\
                        +f"{metric_manager.report_train()} ")

            # visualize periodly
            if ((step != 0) and (step % vis_interval) == 0
                    and current_rank() == 0):
                log(f"Visualizing prediction maps and target maps")
                visualizer.visual_compare(image_batch=image.numpy(), mask_batch=mask.numpy(), predict_x=predict_x, target_x=target_x,\
                                                    name=f"train_{step}")

            # save result and ckpt periodly
            if ((step != 0) and (step % save_interval) == 0
                    and current_rank() == 0):
                # save ckpt
                log("saving model ckpt and result...")
                save_step.assign(step)
                save_lr.assign(lr)
                ckpt_save_path = ckpt_manager.save()
                log(f"ckpt save_path:{ckpt_save_path} saved!\n")
                # save train model
                model_save_path = os.path.join(model_dir, "newest_model.npz")
                train_model.save_weights(model_save_path)
                log(f"model save_path:{model_save_path} saved!\n")
                # save discriminator model
                if domainadapt_flag:
                    dis_save_path = os.path.join(model_dir,
                                                 "newest_discriminator.npz")
                    adapt_dis.save_weights(dis_save_path)
                    log(f"discriminator save_path:{dis_save_path} saved!\n")