Example 1
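This example shows a TensorFlow 1.x training function: it builds the graph, optionally restores pretrained weights, then alternates between inpainter and generator updates while periodically writing TensorBoard summaries, saving checkpoints, and evaluating mask IoU on a validation set.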
# Imports used below. The helpers Train_Graph, Summary, Progbar, myprint,
# myinput, and Permute_IoU are assumed to come from the surrounding
# repository and are not shown here.
import os
import time

import numpy as np
import tensorflow as tf


def train(FLAGS):
    graph = Train_Graph(FLAGS)
    graph.build()

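    # Merged TensorBoard summary ops: periodic training scalars, the generator
    # loss, per-branch losses, and the evaluation IoU.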
    summary_op, generator_summary_op, branch_summary_op, eval_summary_op = Summary.collect_CIS_summary(
        graph, FLAGS)
    with tf.name_scope("parameter_count"):
        total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \
                                for v in tf.trainable_variables()])

    save_vars = tf.global_variables('Inpainter') + tf.global_variables('Generator') + \
        tf.global_variables('train_op')  # includes the global step
    if FLAGS.resume_inpainter:
        assert os.path.isfile(FLAGS.inpainter_ckpt + '.index')
        inpainter_saver = tf.train.Saver(tf.trainable_variables(
            'Inpainter'))  # only restore the trainable variables

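    # Optionally restore ImageNet-pretrained ResNet weights: map each ResNet
    # variable under the Generator scope to its name in the resnet_v2_50
    # checkpoint, keeping only the variables present in both.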
    if FLAGS.resume_resnet:
        assert os.path.isfile(FLAGS.resnet_ckpt)
        resnet_reader = tf.train.NewCheckpointReader(FLAGS.resnet_ckpt)
        resnet_map = resnet_reader.get_variable_to_shape_map()
        resnet_dict = dict()
        for v in tf.trainable_variables('Generator/resnet_v2'):
            # Recover the checkpoint name by stripping the enclosing scope
            # prefix (21 characters under this repo's scope naming).
            ckpt_name = 'resnet_v2_50/' + v.op.name[21:]
            if ckpt_name in resnet_map:
                resnet_dict[ckpt_name] = v
        resnet_saver = tf.train.Saver(resnet_dict)

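    # Checkpoints and summaries are written manually in the training loop, so
    # the Supervisor is created with saver=None and save_summaries_secs=0; it
    # still provides session management and the summary writer.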
    saver = tf.train.Saver(save_vars, max_to_keep=100)
    branch_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "branch" + str(m)))
                      for m in range(FLAGS.num_branch)]  # one writer per branch, for per-branch losses
    sv = tf.train.Supervisor(
        logdir=os.path.join(FLAGS.checkpoint_dir, "CIS_Sum"),
        saver=None,
        save_summaries_secs=0)

    with sv.managed_session() as sess:
        myprint("Total number of trainable parameters: {0}\n".format(
            sess.run(total_parameter_count)))
        if FLAGS.resume_fullmodel:
            assert os.path.isfile(FLAGS.fullmodel_ckpt + '.index')
            saver.restore(sess, FLAGS.fullmodel_ckpt)
            myprint("Resumed training from model {}".format(
                FLAGS.fullmodel_ckpt))
            myprint("Starting from step {}".format(
                sess.run(graph.global_step)))
            myprint("Saving checkpoints to {}".format(FLAGS.checkpoint_dir))
            if os.path.dirname(FLAGS.fullmodel_ckpt) != FLAGS.checkpoint_dir:
                print("\033[0;30;41m" +
                      "Warning: checkpoint_dir and fullmodel_ckpt do not match" +
                      "\033[0m")
                myprint("Please make sure new checkpoints are saved in the "
                        "same directory as the resumed model")
        else:
            if FLAGS.resume_inpainter:
                assert os.path.isfile(FLAGS.inpainter_ckpt + '.index')
                inpainter_saver.restore(sess, FLAGS.inpainter_ckpt)
                myprint("Loaded pretrained inpainter {}".format(
                    FLAGS.inpainter_ckpt))

            if FLAGS.resume_resnet:
                resnet_saver.restore(sess, FLAGS.resnet_ckpt)
                myprint("Loaded pretrained resnet {}".format(FLAGS.resnet_ckpt))
            if not FLAGS.resume_resnet and not FLAGS.resume_inpainter:
                myprint("Training from scratch")
        myinput('Press Enter to continue')

        start_time = time.time()
        step = sess.run(graph.global_step)
        progbar = Progbar(target=FLAGS.ckpt_steps)  # 100k

        sum_iters = FLAGS.iters_gen + FLAGS.iters_inp
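        # One alternation cycle is sum_iters steps long: the first iters_inp
        # steps update the inpainter, the remaining iters_gen steps update the
        # generator (see the step % sum_iters test below).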

        while (time.time() - start_time) < FLAGS.max_training_hrs * 3600:
            if sv.should_stop():
                break

            fetches = {
                "global_step_inc": graph.incr_global_step,
                "step": graph.global_step
            }

            if step % sum_iters < FLAGS.iters_inp:
                fetches['train_op'] = graph.train_ops['Inpainter']
            else:
                fetches['train_op'] = graph.train_ops['Generator']

            if step % FLAGS.summaries_steps == 0:
                fetches["Inpainter_Loss"], fetches["Generator_Loss"] \
                    = graph.loss['Inpainter'], graph.loss['Generator']
                fetches["Inpainter_branch_Loss"], fetches["Generator_branch_Loss"] \
                    = graph.loss['Inpainter_branch'], graph.loss['Generator_branch']
                fetches['Generator_Loss_denominator'] = graph.loss[
                    'Generator_denominator']
                fetches['summary'] = summary_op

            if step % FLAGS.ckpt_steps == 0:
                fetches['generated_masks'] = graph.generated_masks
                fetches['GT_masks'] = graph.GT_masks

            results = sess.run(fetches, feed_dict={graph.is_training: True})
            progbar.update(step % FLAGS.ckpt_steps)

            if step % FLAGS.summaries_steps == 0:
                print("   Step: %3dk  time: %4.4f min   Inpainter loss: %4.2f  Generator loss: %4.2f"
                      % (step // 1000, (time.time() - start_time) / 60,
                         results['Inpainter_Loss'], results['Generator_Loss']))
                sv.summary_writer.add_summary(results['summary'], step)

                generator_summary = sess.run(generator_summary_op,
                                             feed_dict={
                                                 graph.loss['Generator_var']:
                                                 results['Generator_Loss']
                                             })
                sv.summary_writer.add_summary(generator_summary, step)

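                # Each branch writes to its own FileWriter so TensorBoard can
                # overlay the per-branch loss curves.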
                for m in range(FLAGS.num_branch):
                    branch_summary = sess.run(
                        branch_summary_op,
                        feed_dict={
                            graph.loss['Inpainter_branch_var']:
                                np.mean(results['Inpainter_branch_Loss'][:, m], axis=0),
                            graph.loss['Generator_branch_var']:
                                np.mean(results['Generator_branch_Loss'][:, m], axis=0),
                            graph.loss['Generator_denominator_var']:
                                np.mean(results['Generator_Loss_denominator'][:, m], axis=0),
                        })
                    branch_writers[m].add_summary(branch_summary, step)

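            # Every ckpt_steps: save a checkpoint, reset the progress bar, and
            # evaluate mask IoU on a fixed number of validation samples.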
            if step % FLAGS.ckpt_steps == 0:
                saver.save(sess,
                           os.path.join(FLAGS.checkpoint_dir, 'model'),
                           global_step=step)
                progbar = Progbar(target=FLAGS.ckpt_steps)

                # evaluation
                sess.run(graph.val_iterator.initializer)
                fetches = {
                    'GT_masks': graph.GT_masks,
                    'generated_masks': graph.generated_masks
                }

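                # Permute_IoU (a repo helper) presumably matches predicted
                # masks to the ground truth up to a permutation of the masks
                # before computing IoU, since branches carry no fixed order.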
                num_sample = 200
                niter = num_sample // FLAGS.batch_size
                score = 0
                for it in range(niter):
                    results_val = sess.run(
                        fetches, feed_dict={graph.is_training: False})
                    for k in range(FLAGS.batch_size):
                        score += Permute_IoU(
                            label=results_val['GT_masks'][k],
                            pred=results_val['generated_masks'][k])
                # Average over the samples actually evaluated; num_sample is
                # rounded down to a whole number of batches.
                score = score / (niter * FLAGS.batch_size)
                eval_summary = sess.run(
                    eval_summary_op,
                    feed_dict={graph.loss['EvalIoU_var']: score})
                sv.summary_writer.add_summary(eval_summary, step)
            step = results['step']

        myprint("Training completed")
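
For context, a minimal, hypothetical driver for this function could look like the sketch below. The flag names are the ones train actually reads; the default values and the entry-point layout are assumptions for illustration, not part of the example above.

flags = tf.app.flags
flags.DEFINE_string('checkpoint_dir', './checkpoints', 'Output directory for checkpoints and summaries.')
flags.DEFINE_string('fullmodel_ckpt', '', 'Full-model checkpoint to resume from.')
flags.DEFINE_string('inpainter_ckpt', '', 'Pretrained inpainter checkpoint.')
flags.DEFINE_string('resnet_ckpt', '', 'ImageNet-pretrained resnet_v2_50 checkpoint.')
flags.DEFINE_boolean('resume_fullmodel', False, 'Resume from fullmodel_ckpt.')
flags.DEFINE_boolean('resume_inpainter', False, 'Restore the pretrained inpainter.')
flags.DEFINE_boolean('resume_resnet', False, 'Restore the pretrained ResNet.')
flags.DEFINE_integer('num_branch', 2, 'Number of generator branches.')
flags.DEFINE_integer('batch_size', 4, 'Batch size.')
flags.DEFINE_integer('iters_inp', 1, 'Inpainter steps per alternation cycle.')
flags.DEFINE_integer('iters_gen', 1, 'Generator steps per alternation cycle.')
flags.DEFINE_integer('summaries_steps', 100, 'Steps between summary writes.')
flags.DEFINE_integer('ckpt_steps', 100000, 'Steps between checkpoints and evaluations.')
flags.DEFINE_float('max_training_hrs', 24.0, 'Wall-clock training budget in hours.')
FLAGS = flags.FLAGS


def main(_argv):
    train(FLAGS)


if __name__ == '__main__':
    tf.app.run()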