Example #1
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training videos = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(
        opt, models)

    ### set parameters
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(
                data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, inst_A = data_loader.dataset.prepare_data(
                    data, i, input_nc, output_nc)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_prev_last)

                ####### discriminator
                ### individual frame discriminator
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                flow_ref, conf_ref = flowNet(
                    real_B, real_B_prev)  # reference flows and confidences
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(frames_all, \
                        real_B, fake_B, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    fake_B_first = fake_B[0, 0]  # the first generated image in this sequence

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
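
All of the examples route parameter updates through a loss_backward helper whose definition is not shown in these listings. A minimal sketch consistent with the three-argument call sites above, assuming an fp16 option and the NVIDIA Apex amp API for the mixed-precision branch:

def loss_backward(opt, loss, optimizer):
    # Hypothetical sketch; the real helper is defined elsewhere in the repo.
    optimizer.zero_grad()             # clear gradients from the previous step
    if getattr(opt, 'fp16', False):   # assumed mixed-precision flag
        from apex import amp          # assumed dependency (NVIDIA Apex)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()    # backward pass on the loss-scaled value
    else:
        loss.backward()               # standard full-precision backward pass
    optimizer.step()                  # apply the parameter update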
Example #2
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    # Initialize dataset:======================================================#
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('Number of training videos = %d' % dataset_size)

    # Initialize models:=======================================================#
    models = prepare_models(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = \
            create_optimizer(opt, models)

    # Set parameters:==========================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
    start_epoch, epoch_iter, print_freq, total_steps, iter_path = \
    init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    # Initialize loss list:====================================================#
    losses_G = []
    losses_D = []
    losses_D_T = []
    losses_t_scales = []

    # Real training starts here:===============================================#
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = \
                data_loader.dataset.init_data_params(data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, input_C, inst_A = \
                        data_loader.dataset.prepare_data(data, i, input_nc, output_nc)

                ############################### Forward Pass ###############################
                ####### Generator:=========================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = \
                        modelG(input_A, input_B, inst_A, fake_B_prev_last)

                ####### Discriminator:=====================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref, input_C
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = \
                    modelD.module.get_all_skipped_frames(
                        frames_all, real_B, fake_B, flow_ref, conf_ref,
                        t_scales, tD, n_frames_load, i, flowNet)
                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = \
                                modelD.module.get_losses(loss_dict, loss_dict_T, t_scales)

                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])
                # the first generated image in this sequence
                if i == 0: fake_B_first = fake_B[0, 0]

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            # Display results and errors:=====================================#
            # Print out errors:===============================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                                           else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                                   if not isinstance(v, int) \
                                   else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # Save latest model:===============================================#
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batch_size:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params:=================#
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)

        from matplotlib import pyplot as plt
        plt.switch_backend('agg')  # headless backend: render without a display
        print("Generator Loss: %f." % losses_G[-1])
        print("Discriminator Loss: %f." % losses_D[-1])
        # Plot losses
        plt.plot(losses_G, '-b', label='losses_G')
        plt.plot(losses_D, '-r', label='losses_D')
        # plt.plot(losses_D_T, '-r', label='losses_D_T')
        plt.legend()  # without this call the labels above are never rendered
        plot_name = 'checkpoints/' + opt.name + '/losses_plot.png'
        plt.savefig(plot_name)
        plt.close()
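
Every example wraps the frame discriminator's inputs in a reshape helper that is not defined in these listings. A plausible sketch, inferred from the call sites (video tensors are 5-D, batch x time x channels x height x width, while frame-level discriminators expect 4-D batches):

import torch

def reshape(tensors):
    # Hypothetical sketch: fold the batch and time dimensions of each 5-D
    # video tensor into a single batch dimension, recursing over lists and
    # passing None entries through unchanged.
    if tensors is None:
        return None
    if isinstance(tensors, list):
        return [reshape(t) for t in tensors]
    _, _, ch, h, w = tensors.size()
    return tensors.contiguous().view(-1, ch, h, w)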
Example #3
def train():
    opt = vars(TrainOptions().parse())
    # Initialize dataset:==============================================================================================#
    data_loader = create_dataloader(**opt)
    dataset_size = len(data_loader)
    print(f'Number of training videos = {dataset_size}')
    # Initialize models:===============================================================================================#
    models = prepare_models(**opt)
    model_g, model_d, flow_net, optimizer_g, optimizer_d, optimizer_d_t = create_optimizer(
        models, **opt)
    # Set parameters:==================================================================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
    start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(model_g, model_d, data_loader, **opt)
    visualizer = Visualizer(**opt)

    # Initialize loss list:============================================================================================#
    losses_G = []
    losses_D = []
    # Training start:==================================================================================================#
    for epoch in range(start_epoch, opt['niter'] + opt['niter_decay'] + 1):
        epoch_start_time = time.time()
        for idx, video in enumerate(data_loader, start=epoch_iter):
            if not total_steps % print_freq:
                iter_start_time = time.time()
            total_steps += opt['batch_size']
            epoch_iter += opt['batch_size']

            # whether to collect output images
            save_fake = total_steps % opt['display_freq'] == 0
            fake_B_prev_last = None
            # all real/generated frames so far
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None
            if opt['sparse_D']:
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = \
                    [None] * t_scales, [None] * t_scales, \
                    [None] * t_scales, [None] * t_scales
            frames_all = real_B_all, fake_B_all, flow_ref_all, conf_ref_all

            for i, (input_A, input_B) in enumerate(VideoSeq(**video, **opt)):

                # Forward Pass:========================================================================================#
                # Generator:===========================================================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = model_g(
                    input_A, input_B, fake_B_prev_last)

                # Discriminator:=======================================================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flow_net(real_B, real_B_prev)
                fake_B_prev = model_g.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last

                losses = model_d(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(model_d.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = \
                    model_d.get_all_skipped_frames(
                        frames_all, real_B, fake_B, flow_ref, conf_ref,
                        t_scales, tD, video.n_frames_load, i, flow_net)
                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = model_d(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(model_d.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = model_d.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(loss_G, optimizer_g)

                # update individual discriminator weights
                loss_backward(loss_D, optimizer_d)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(loss_D_T[s], optimizer_d_t[s])
                # the first generated image in this sequence
                if i == 0:
                    fake_B_first = fake_B[0, 0]

            # Display results and errors:==============================================#
            # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                    else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                        if not isinstance(v, int) \
                        else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, model_d)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # Save latest model:===============================================#
            save_models(epoch, epoch_iter, total_steps, visualizer, iter_path,
                        model_g, model_d, **opt)
            if epoch_iter > dataset_size - opt['batch_size']:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print(
            f'End of epoch {epoch} / {opt["niter"] + opt["niter_decay"]} \t'
            f' Time Taken: {time.time() - epoch_start_time:.0f} sec')

        # save model for this epoch and update model params:=================#
        save_models(epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    model_g,
                    model_d,
                    end_of_epoch=True,
                    **opt)
        update_models(epoch, model_g, model_d, data_loader, **opt)

        from matplotlib import pyplot as plt
        plt.switch_backend('agg')  # headless backend: render without a display
        print("Generator Loss: %f." % losses_G[-1])
        print("Discriminator Loss: %f." % losses_D[-1])
        # Plot losses
        plt.plot(losses_G, '-b', label='losses_G')
        plt.plot(losses_D, '-r', label='losses_D')
        # plt.plot(losses_D_T, '-r', label='losses_D_T')
        plt.legend()  # without this call the labels above are never rendered
        plot_name = 'checkpoints/' + opt['name'] + '/losses_plot.png'
        plt.savefig(plot_name)
        plt.close()
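
Example #3 refactors option handling relative to the other examples: vars() converts the parsed argparse Namespace into a plain dict, so each helper receives options as keyword arguments instead of a monolithic opt object. A small sketch of the pattern (the option keys shown are illustrative):

# Options-as-dict pattern used above; the key names here are illustrative.
opt = vars(TrainOptions().parse())            # Namespace -> {'niter': ..., 'batch_size': ..., ...}
data_loader = create_dataloader(**opt)        # each helper consumes only the keys it needs
n_epochs = opt['niter'] + opt['niter_decay']  # dict indexing replaces attribute access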
Example #4
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1    
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)    
    print('#training frames = %d' % dataset_size)

    ### initialize models
    models = create_model_cloth(opt)
    ClothWarper, ClothWarperLoss, flowNet, optimizer = create_optimizer_cloth(opt, models)

    ### set parameters    
    n_gpus, tG, input_nc_1, input_nc_2, input_nc_3, start_epoch, epoch_iter, print_freq, total_steps, iter_path, tD, t_scales = init_params(opt, ClothWarper, data_loader)
    visualizer = Visualizer(opt)    

    ### real training starts here  
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()    
        for idx, data in enumerate(dataset, start=epoch_iter):        
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params_cloth(data, n_gpus, tG)
            flow_total_prev_last, frames_all = data_loader.dataset.init_data_cloth(t_scales)

            for i in range(0, n_frames_total, n_frames_load):
                is_first_frame = flow_total_prev_last is None
                input_TParsing, input_TFG, input_SParsing, input_SFG, input_SFG_full = data_loader.dataset.prepare_data_cloth(data, i)

                ###################################### Forward Pass ##########################
                ####### C2F-FWN                  
                fg_tps, fg_dense, lo_tps, lo_dense, flow_tps, flow_dense, flow_totalp, \
                real_input_1, real_input_2, real_SFG, real_SFG_fullp, flow_total_last = \
                    ClothWarper(input_TParsing, input_TFG, input_SParsing,
                                input_SFG, input_SFG_full, flow_total_prev_last)
                real_SLO = real_input_2[:, :, -opt.label_nc_2:]
                ####### compute losses
                ### individual frame losses and FTC loss with l=1
                real_SFG_full_prev, real_SFG_full = real_SFG_fullp[:, :-1], real_SFG_fullp[:, 1:]   # the collection of previous and current real frames
                flow_optical_ref, conf_optical_ref = flowNet(real_SFG_full, real_SFG_full_prev)       # reference flows and confidences                
                
                flow_total_prev, flow_total = flow_totalp[:, :-1], flow_totalp[:, 1:]
                if is_first_frame:
                    flow_total_prev = flow_total_prev[:, 1:]

                flow_total_prev_last = flow_total_last
                
                losses, flows_sampled_0 = ClothWarperLoss(
                    0,
                    reshape([real_SFG, real_SLO, fg_tps, fg_dense, lo_tps, lo_dense,
                             flow_tps, flow_dense, flow_total, flow_total_prev,
                             flow_optical_ref, conf_optical_ref]),
                    is_first_frame)
                losses = [ torch.mean(x) if x is not None else 0 for x in losses ]
                loss_dict = dict(zip(ClothWarperLoss.module.loss_names, losses))          

                ### FTC losses with l=3,9
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = ClothWarperLoss.module.get_all_skipped_frames(frames_all, \
                        real_SFG_full, flow_total, flow_optical_ref, conf_optical_ref, real_SLO, t_scales, tD, n_frames_load, i, flowNet)                                

                # compute losses for l=3,9
                loss_dict_T = []
                for s in range(1, t_scales):                
                    if frames_skipped[0][s] is not None and not opt.tps_only:                        
                        losses, flows_sampled_1 = ClothWarperLoss(s+1, [frame_skipped[s] for frame_skipped in frames_skipped], False)
                        losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
                        loss_dict_T.append(dict(zip(ClothWarperLoss.module.loss_names_T, losses)))                  

                # collect losses
                loss, _ = ClothWarperLoss.module.get_losses(loss_dict, loss_dict_T, t_scales-1)

                ###################################### Backward Pass #################################                 
                # update generator weights     
                loss_backward(opt, loss, optimizer)                

                if i == 0: fg_dense_first = fg_dense[0, 0]   # the first generated image in this sequence


            if opt.debug:
                call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k+str(s): v.data.item() if not isinstance(v, int) else v for k, v in loss_dict_T[s].items()})

                loss_names_vis = ClothWarperLoss.module.loss_names.copy()
                for s in range(len(loss_dict_T)):  # plain loop instead of a side-effecting set comprehension that shadowed the outer idx
                    loss_names_vis.append(ClothWarperLoss.module.loss_names_T[0] + str(s))
                visualizer.print_current_errors_new(epoch, epoch_iter, errors, loss_names_vis, t)
                visualizer.plot_current_errors(errors, total_steps)
            ### display output images
            if save_fake:                
                visuals = util.save_all_tensors_cloth(opt, real_input_1, real_input_2, fg_tps, fg_dense, lo_tps, lo_dense, fg_dense_first, real_SFG, real_SFG_full, flow_tps, flow_dense, flow_total)            
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models_cloth(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, ClothWarper)            
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break
           
        # end of epoch 
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models_cloth(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, ClothWarper, end_of_epoch=True)
        update_models_cloth(opt, epoch, ClothWarper, data_loader)
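
All four examples build inputs for multiple temporal scales via a get_all_skipped_frames helper whose body is not shown (it also maintains a rolling frame buffer across chunks). Conceptually, at temporal scale s a tD-frame window is sampled with stride tD ** s, which is consistent with the l=1 and l=3,9 comments above for tD = 3. A toy sketch of just the sampling step:

def skip_frames(frames, tD, s):
    # Conceptual sketch only, not the repo's actual helper: take every
    # (tD ** s)-th frame along the time axis, then keep the last tD of them,
    # so higher scales cover longer time spans with the same window size.
    stride = tD ** s                     # e.g. tD = 3 gives strides 1, 3, 9, ...
    return frames[:, ::stride][:, -tD:]  # frames: [batch, time, ...]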