def train(FLAGS):
    # Build the end-to-end training graph
    graph = Train_Graph(FLAGS)
    graph.build()

    summary_op, tex_latent_summary_op, bg_latent_summary_op, eval_summary_op = Summary.collect_end2end_summary(graph, FLAGS)
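    # summary_op is the merged per-step training summary; the tex/bg latent ops
    # and eval_summary_op are run separately below, fed per-dimension KL values
    # and the IoU score through feedable tensors in graph.loss.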
    # Count trainable parameters (for logging only)
    with tf.name_scope("parameter_count"):
        total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v))
                                for v in tf.trainable_variables()])

    # Save everything: Inpainter, Generator, VAE, Fusion and train_op
    # variables (including the global steps)
    save_vars = tf.global_variables()
    
    if FLAGS.resume_CIS:
        CIS_vars = tf.global_variables('Inpainter')+tf.global_variables('Generator')
        CIS_saver = tf.train.Saver(CIS_vars, max_to_keep=100)

    # Separate savers so the pretrained mask/texture VAEs can be restored independently
    mask_saver = tf.train.Saver(tf.global_variables('VAE//separate/maskVAE/'), max_to_keep=100)
    tex_saver = tf.train.Saver(tf.global_variables('VAE//separate/texVAE/'), max_to_keep=100)

    saver = tf.train.Saver(save_vars, max_to_keep=100)
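
    # One FileWriter per object branch / latent dimension: writing the same
    # scalar tag to different writers makes TensorBoard overlay the curves as
    # separate runs, one per branch or latent unit.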
    branch_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir,'branch'+str(m))) for m in range(FLAGS.num_branch)]
    tex_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "tex_latent"+str(m))) for m in range(FLAGS.tex_dim)]
    bg_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "bg_latent"+str(m))) for m in range(FLAGS.bg_dim)]
    # mask_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "mask_latent"+str(m))) for m in range(FLAGS.mask_dim)]

    sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "end2end_Sum"),
                             saver=None, save_summaries_secs=0)  # no automatic saving; checkpoints are written manually below

    with sv.managed_session() as sess:
        myprint ("Number of total params: {0} \n".format( \
            sess.run(total_parameter_count)))
        if FLAGS.resume_fullmodel:
            assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index'), "fullmodel checkpoint not found: "+FLAGS.fullmodel_ckpt
            saver.restore(sess, FLAGS.fullmodel_ckpt)
            myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt))
            myprint ("Start from step {} vae_step{}".format(sess.run(graph.global_step), sess.run(graph.vae_global_step)))
            myprint ("Save checkpoint in          {}".format(FLAGS.checkpoint_dir))
            if os.path.dirname(FLAGS.fullmodel_ckpt) != FLAGS.checkpoint_dir:
                print ("\033[0;30;41m"+"Warning: checkpoint_dir and fullmodel_ckpt do not match; "
                       "new checkpoints will not be saved next to the resumed model"+"\033[0m")
        else:
            if os.path.isfile(FLAGS.mask_ckpt+'.index'):
                mask_saver.restore(sess, FLAGS.mask_ckpt)
                myprint ("Load pretrained maskVAE {}".format(FLAGS.mask_ckpt))
            if os.path.isfile(FLAGS.tex_ckpt+'.index'):   
                tex_saver.restore(sess, FLAGS.tex_ckpt)
                myprint ("Load pretrained texVAE {}".format(FLAGS.tex_ckpt))
            if FLAGS.resume_CIS:
                assert os.path.isfile(FLAGS.CIS_ckpt+'.index'), "CIS checkpoint not found: "+FLAGS.CIS_ckpt
                CIS_saver.restore(sess, FLAGS.CIS_ckpt)
                myprint ("Load pretrained inpainter and generator {}".format(FLAGS.CIS_ckpt))
            else:
                myprint ("Train from scratch")
        myinput('Press enter to continue')

        start_time = time.time()
        step = sess.run(graph.global_step)
        vae_step = sess.run(graph.vae_global_step)
        progbar = Progbar(target=FLAGS.ckpt_steps)  # progress over one checkpoint interval (e.g. 100k steps)

        sum_iters = FLAGS.iters_gen_vae + FLAGS.iters_inp
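
        # The loop alternates over windows of sum_iters steps: the first
        # iters_inp steps of each window update only the Inpainter, the
        # remaining iters_gen_vae steps update the Generator/VAE/Fusion ops.
        # E.g. with iters_inp=2 and iters_gen_vae=3 the phases repeat as
        # I,I,G,G,G | I,I,G,G,G | ...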

        while (time.time()-start_time)<FLAGS.max_training_hrs*3600:
            if sv.should_stop():
                break

            fetches = {"global_step_inc": graph.incr_global_step, "step": graph.global_step}

            if step%sum_iters < FLAGS.iters_inp:
                fetches['train_op'] = graph.train_ops['Inpainter']
                mask_capacity = vae_step*FLAGS.mask_capacity_inc
            else:
                fetches['train_op'] = graph.train_ops #'VAE//separate/texVAE/','VAE//separate/texVAE_BG/', 'VAE//fusion', 'Fusion'
                mask_capacity = vae_step*FLAGS.mask_capacity_inc  #-> should have an VAE step
                fetches['vae_global_step'], fetches['vae_global_step_inc'] = graph.vae_global_step, graph.incr_vae_global_step

            if step % FLAGS.summaries_steps == 0:
                fetches["Inpainter_Loss"],fetches["Generator_Loss"] = graph.loss['Inpainter'], graph.loss['Generator']
                fetches["VAE//texVAE"], fetches["VAE//texVAE_BG"], fetches['VAE//fusion'] = graph.loss['VAE//separate/texVAE/'], graph.loss['VAE//separate/texVAE_BG/'], graph.loss['VAE//fusion']
                fetches['tex_kl'], fetches['bg_kl'] = graph.loss['tex_kl'], graph.loss['bg_kl']
                fetches['summary'] = summary_op

            if step % FLAGS.ckpt_steps == 0:
                fetches['generated_masks'] = graph.generated_masks
                fetches['GT_masks'] = graph.GT_masks

            # One training step; graph.is_training and graph.mask_capacity are placeholders defined in Train_Graph
            results = sess.run(fetches, feed_dict={graph.is_training: True, graph.mask_capacity: mask_capacity})
            progbar.update(step%FLAGS.ckpt_steps)

            if step % FLAGS.summaries_steps == 0:
                print ("   Step:%3dk time:%4.4fmin   VAELoss:%4.2f"
                    %(step//1000, (time.time()-start_time)/60, results["VAE//texVAE"]+results['VAE//fusion']+results['VAE//texVAE_BG']))
                sv.summary_writer.add_summary(results['summary'], step)

                for d in range(FLAGS.tex_dim):
                    tex_summary = sess.run(tex_latent_summary_op, feed_dict={graph.loss['tex_kl_var']: results['tex_kl'][d]})
                    tex_latent_writers[d].add_summary(tex_summary, step)
                    
                for d in range(FLAGS.bg_dim):
                    bg_summary = sess.run(bg_latent_summary_op, feed_dict={graph.loss['bg_kl_var']: results['bg_kl'][d]})
                    bg_latent_writers[d].add_summary(bg_summary, step)

                # for d in range(FLAGS.mask_dim):
                #     mask_summary = sess.run(mask_latent_summary_op, feed_dict={graph.loss['mask_kl_var']: results['mask_kl'][d]})
                #     mask_latent_writers[d].add_summary(mask_summary, step)

            if step % FLAGS.ckpt_steps == 0:
                saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model'), global_step=step)
                progbar = Progbar(target=FLAGS.ckpt_steps)  # reset for the next interval

                # Evaluation: rewind the validation iterator and score generated
                # masks against ground truth with permutation-invariant IoU
                sess.run(graph.val_iterator.initializer)
                fetches = {'GT_masks': graph.GT_masks, 'generated_masks': graph.generated_masks}

                if FLAGS.dataset in ['multi_texture', 'flying_animals']:
                    # For multi_texture, bg_num is not a real background count; it is
                    # the number of validation samples drawn per image type.
                    # Use a comprehension here: [[]]*max_num would alias a single list.
                    score = [[] for _ in range(FLAGS.max_num)]
                    for bg in range(FLAGS.bg_num):
                        results_val = sess.run(fetches, feed_dict={graph.is_training: False})
                        for k in range(FLAGS.max_num):
                            score[k].append(Permute_IoU(label=results_val['GT_masks'][k], pred=results_val['generated_masks'][k]))
                    for k in range(FLAGS.max_num):
                        eval_summary = sess.run(eval_summary_op, feed_dict={graph.loss['EvalIoU_var']: np.mean(score[k])})
                        branch_writers[k+1].add_summary(eval_summary, step)  # note the +1 offset into branch_writers
                else:
                    num_sample = FLAGS.skipnum
                    niter = num_sample//FLAGS.batch_size
                    assert num_sample%FLAGS.batch_size==0, "skipnum must be divisible by batch_size"
                    score = 0
                    for it in range(niter):
                        results_val = sess.run(fetches, feed_dict={graph.is_training:False})
                        for k in range(FLAGS.batch_size):
                            score += Permute_IoU(label=results_val['GT_masks'][k], pred=results_val['generated_masks'][k])
                    score = score/num_sample
                    eval_summary = sess.run(eval_summary_op, feed_dict={graph.loss['EvalIoU_var']: score})
                    sv.summary_writer.add_summary(eval_summary, step)

            step = results['step']
            # vae_global_step is only fetched on Generator/VAE iterations;
            # keep the previous value during Inpainter-only iterations
            vae_step = results.get('vae_global_step', vae_step)

        myprint("Training completed")