コード例 #1
0
ファイル: main.py プロジェクト: BigmathAI/nn-rotobrush
def main():
    """Build the data layers and network, then run the action chosen by FLAGS.mode.

    Modes:
      * 'train'    -> solver_wrapper.Train()
      * 'finetune' -> solver_wrapper.Finetune()
      * 'valid'    -> solver_wrapper.Evaluate(), then log each line of the
                      result as formatted by pyutils.dict_to_string(..., 3)
      * 'findbest' -> solver_wrapper.FindBestModel()
    Any other mode prints 'Do Nothing'.
    """
    # A training data layer is only built for the modes that optimise.
    dl_train = data_layer_2d(
        FLAGS, 'train') if FLAGS.mode in ['train', 'finetune'] else None
    dl_valid = data_layer_2d(FLAGS, 'valid')

    net = CENet(FLAGS)

    with tf.Session(config=TFCONFIG) as sess:
        # Initialise global and local variables before the solver uses the session.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        sw = solver_wrapper(net, (dl_train, dl_valid), sess, FLAGS)
        if FLAGS.mode == 'train':
            sw.Train()
        elif FLAGS.mode == 'finetune':
            sw.Finetune()
        elif FLAGS.mode == 'valid':
            total_loss_vals = sw.Evaluate()
            # Plain loop for a side effect instead of a throwaway list.
            for line in pyutils.dict_to_string(total_loss_vals, 3):
                logger.info(line)
        elif FLAGS.mode == 'findbest':
            sw.FindBestModel()
        else:
            print('Do Nothing')
コード例 #2
0
 def log_at_every_itera_end_time(self, status, crt_phase, prgbar, loss_vals,
                                 times):
     """Log a one-line progress summary for this iteration, then every loss.

     Args:
         status: snapshot exposing integer ``epoch`` and ``iteration`` fields.
         crt_phase: current phase name; also the key into ``loss_vals`` whose
             value is shown in the summary line.
         prgbar: progress bar whose ``GetBar()`` text is embedded in the line.
         loss_vals: mapping of loss name -> value; every entry is logged via
             ``pyutils.dict_to_string(loss_vals, 3)``.
         times: ``(data_seconds, optim_seconds)`` pair for this iteration.
     """
     data_secs, optim_secs = times
     summary = ('Ep[{ep:03d}/{ep_total:03d}] Iter{it: 6d}, {bar}, '
                '(D{t_data:3.2f}s R{t_optim:3.2f}s), Ph: {phase:8s}, '
                'Loss: {loss:6.4f}').format(
                    ep=status.epoch,
                    ep_total=self.FLAGS.epoches,
                    it=status.iteration,
                    bar=prgbar.GetBar(),
                    t_data=data_secs,
                    t_optim=optim_secs,
                    phase=crt_phase,
                    loss=loss_vals[crt_phase])
     logger.info(summary)
     for entry in pyutils.dict_to_string(loss_vals, 3):
         logger.info(entry)
コード例 #3
0
    def FindBestModel(self):
        """Evaluate each saved model in FLAGS.log_path and record the best ones.

        The ``checkpoint`` index file in ``FLAGS.log_path`` is temporarily
        rewritten to point at each ``.meta`` model returned by ``fp.dir`` (in
        reversed listing order, skipping names containing 'temp_model').
        ``self.Evaluate()`` is run for each model, and the losses are logged and
        passed to ``self.record_best_model``.

        The original index file is backed up first and restored afterwards,
        whether or not evaluation succeeded. If the backup cannot be made, a
        warning is logged and nothing is evaluated. Errors raised during
        evaluation are printed and logged, not propagated.
        """
        import shutil  # local import keeps this block self-contained

        fname_ckpt_orig = os.path.join(self.FLAGS.log_path, 'checkpoint')
        fname_ckpt_backup = os.path.join(self.FLAGS.log_path,
                                         'checkpoint-backup')
        # os.system('cp ...') never raises, so a failed backup used to go
        # unnoticed; copyfile raises OSError when the copy fails.
        try:
            shutil.copyfile(fname_ckpt_orig, fname_ckpt_backup)
        except OSError:
            logger.warning('The original checkpoint file not backup-ed!')
            return

        suffix = '.meta'
        try:
            fname_ckpt_models = fp.dir(self.FLAGS.log_path,
                                       suffix,
                                       case_sensitive=True)
            for fname in fname_ckpt_models[::-1]:
                if 'temp_model' in fname:
                    continue
                # Drop the literal '.meta' suffix. str.rstrip('.meta') would strip
                # any trailing run of the characters '.', 'm', 'e', 't', 'a'.
                if fname.endswith(suffix):
                    fname = fname[:-len(suffix)]
                fname = fname.replace('\\', '/')
                print(fname)
                # Point the checkpoint index at this model so Evaluate() loads it.
                with open(fname_ckpt_orig, 'w') as f:
                    f.write('model_checkpoint_path: "{}"\n'.format(fname))
                    f.write('all_model_checkpoint_paths: "{}"\n'.format(fname))
                eval_loss_vals = self.Evaluate()
                for line in pyutils.dict_to_string(eval_loss_vals, 3):
                    logger.info(line)
                self.record_best_model(self.best_model_table, eval_loss_vals,
                                       fname)
                print(fname + ' END')
        except Exception:
            traceback.print_exc()
            logger.error('ERROR')
        finally:
            # Always put the original checkpoint index back (best effort).
            try:
                shutil.copyfile(fname_ckpt_backup, fname_ckpt_orig)
            except OSError:
                logger.error('Failed to restore the original checkpoint file!')
コード例 #4
0
    def while_loop(self, saver):
        """Train until ``train_data.status.epoch`` reaches ``FLAGS.epoches``.

        Each iteration does the following:
          * builds a feed dict from ``self.train_data``;
          * runs the network outputs, losses, summaries, the current phase's
            optimiser and ``ph_lr`` in a single ``sess.run``;
          * logs timings and losses.

        Periodic work, keyed on the iteration counter:
          * every 100: print params and ``nvidia-smi``;
          * every 20: draw images;
          * every 5: write summaries and check for a ``stop`` file.

        When an epoch boundary is crossed and the finished epoch is divisible
        by 5, the model is evaluated and checkpointed via ``saver``. The
        checkpoint files are deleted again if ``record_best_model`` returns a
        falsy value.

        Args:
            saver: object whose ``save(sess, path, step)`` returns the saved
                checkpoint prefix (presumably a ``tf.train.Saver`` — confirm).

        Raises:
            ValueError: if a file named ``stop`` exists in the working directory
                (checked every 5 iterations).
        """
        import glob  # used to delete rejected checkpoint files without a shell

        train_data = self.train_data

        while train_data.status.epoch < self.FLAGS.epoches:
            # Snapshot the status; train_data.status is compared against it
            # below to detect an epoch boundary.
            status = edict(**train_data.status)
            train_data.export_status()

            crt_phase = self.compute_crt_phase(status)

            tic = time.time()
            feed_dict, valid_len = self.extract_data_and_build_feed_dict(
                train_data)
            toc = time.time()
            time_cost_data = toc - tic

            ops = {
                **self.net.out,
                **self.net.loss,
                **self.net.summary,
                crt_phase + '_optim':
                self.net.optims[crt_phase],
                'ph_lr':
                self.net.ph_lr,
            }

            tic = time.time()
            op_vals = self.sess.run(ops, feed_dict=feed_dict)
            toc = time.time()
            time_cost_optim = toc - tic

            # Only the first valid_len rows are drawn (presumably the rest is
            # batch padding — confirm against extract_data_and_build_feed_dict).
            draw_vals = {
                k: op_vals[k][:valid_len]
                for k in self.net.out.keys()
            }
            loss_vals = {k: op_vals[k] for k in self.net.loss.keys()}
            smry_vals = {k: op_vals[k] for k in self.net.summary.keys()}

            self.log_at_every_itera_end_time(status, crt_phase,
                                             train_data.prgbar, loss_vals,
                                             (time_cost_data, time_cost_optim))

            if status.iteration % 100 == 0:
                pyutils.print_params(logger, self.FLAGS)
                os.system('nvidia-smi')
            if status.iteration % 20 == 0:
                utils.draw_ims(draw_vals, self.fdout_temp_image)
            if status.iteration % 5 == 0:
                for smry in smry_vals.values():
                    self.summary_writer_train.add_summary(smry, status.iteration)
            if status.iteration % 5 == 0 and os.path.exists('stop'):
                raise ValueError('Stop file exists!!! Quit!!!')
            if train_data.status.epoch != status.epoch and status.epoch % 5 == 0:
                eval_loss_vals = self.Evaluate(status.epoch)
                logger.info('Eval: Ph {:>12s}: {:6.4f}'.format(
                    crt_phase, eval_loss_vals[crt_phase]))
                for line in pyutils.dict_to_string(eval_loss_vals, 3):
                    logger.info(line)
                fname_ckpt_model = self.fname_ckpt_model.format(crt_phase)
                fname_ckpt_model = saver.save(self.sess, fname_ckpt_model,
                                              status.epoch)
                need_to_save = self.record_best_model(
                    self.best_model_table, eval_loss_vals, fname_ckpt_model)
                if not need_to_save:
                    logger.warning('DELETE MODEL: rm {}.*'.format(fname_ckpt_model))
                    # Escape the prefix so glob metacharacters in the path match
                    # literally; delete best-effort like `rm` did.
                    for path in glob.glob(glob.escape(fname_ckpt_model) + '.*'):
                        try:
                            os.remove(path)
                        except OSError as err:
                            logger.warning('Could not delete {}: {}'.format(
                                path, err))
                self.pgrbar.Update(1)
                logger.info('ENTIRE-PROGRESS: {}\n'.format(
                    self.pgrbar.GetBar()))
コード例 #5
0
    def while_loop(self, saver, export_every=1):
        """Train until ``train_data.status.epoch`` reaches ``FLAGS.epoches``, exporting a frozen graph.

        Each iteration does the following:
          * builds a feed dict from ``self.train_data``;
          * runs the network outputs, losses, summaries, the current phase's
            optimiser and ``ph_lr`` in a single ``sess.run``;
          * logs timings and losses;
          * optionally freezes the graph: variables are converted to constants
            for the output node ``cpu_variables/generator_im/segmented``, and
            the result is written to ``models/model.pb``.

        Periodic work, keyed on the iteration counter:
          * every 100: print params and ``nvidia-smi``, then draw images;
          * every 5: write summaries and check for a ``stop`` file.

        When an epoch boundary is crossed and the finished epoch is divisible
        by 5, the model is evaluated and checkpointed via ``saver``. The
        checkpoint files are deleted again if ``record_best_model`` returns a
        falsy value.

        Args:
            saver: object whose ``save(sess, path, step)`` returns the saved
                checkpoint prefix (presumably a ``tf.train.Saver`` — confirm).
            export_every: export the frozen graph on iterations where
                ``status.iteration % export_every == 0``. The default ``1``
                exports after every step, as before. Freezing re-evaluates the
                variables feeding the output node and rewrites the file each
                time, so callers may pass a larger value. Pass ``0``/``None``
                to disable the export.

        Raises:
            ValueError: if a file named ``stop`` exists in the working directory
                (checked every 5 iterations).
        """
        import glob  # used to delete rejected checkpoint files without a shell
        # Imported once up front rather than on every loop iteration.
        from tensorflow.python.framework.graph_util import convert_variables_to_constants

        train_data = self.train_data

        while train_data.status.epoch < self.FLAGS.epoches:
            # Snapshot the status; train_data.status is compared against it
            # below to detect an epoch boundary.
            status = edict(**train_data.status)
            train_data.export_status()

            crt_phase = self.compute_crt_phase(status)

            tic = time.time()
            feed_dict, valid_len = self.extract_data_and_build_feed_dict(
                train_data)
            toc = time.time()
            time_cost_data = toc - tic

            ops = {
                **self.net.out,
                **self.net.loss,
                **self.net.summary,
                crt_phase + '_optim':
                self.net.optims[crt_phase],
                'ph_lr':
                self.net.ph_lr,
            }

            tic = time.time()
            op_vals = self.sess.run(ops, feed_dict=feed_dict)
            toc = time.time()
            time_cost_optim = toc - tic

            # Only the first valid_len rows are drawn (presumably the rest is
            # batch padding — confirm against extract_data_and_build_feed_dict).
            draw_vals = {
                k: op_vals[k][:valid_len]
                for k in self.net.out.keys()
            }
            loss_vals = {k: op_vals[k] for k in self.net.loss.keys()}
            smry_vals = {k: op_vals[k] for k in self.net.summary.keys()}

            self.log_at_every_itera_end_time(status, crt_phase,
                                             train_data.prgbar, loss_vals,
                                             (time_cost_data, time_cost_optim))

            if export_every and status.iteration % export_every == 0:
                graph = convert_variables_to_constants(
                    self.sess, self.sess.graph_def,
                    ['cpu_variables/generator_im/segmented'])
                tf.train.write_graph(graph, 'models', 'model.pb', as_text=False)

            if status.iteration % 100 == 0:
                pyutils.print_params(logger, self.FLAGS)
                os.system('nvidia-smi')
                utils.draw_ims(draw_vals, self.fdout_temp_image)
            if status.iteration % 5 == 0:
                for smry in smry_vals.values():
                    self.summary_writer_train.add_summary(smry, status.iteration)
            if status.iteration % 5 == 0 and os.path.exists('stop'):
                raise ValueError('Stop file exists!!! Quit!!!')
            if train_data.status.epoch != status.epoch and status.epoch % 5 == 0:
                eval_loss_vals = self.Evaluate(status.epoch)
                logger.info('Eval: Ph {:>12s}: {:6.4f}'.format(
                    crt_phase, eval_loss_vals[crt_phase]))
                for line in pyutils.dict_to_string(eval_loss_vals, 3):
                    logger.info(line)
                fname_ckpt_model = self.fname_ckpt_model.format(crt_phase)
                fname_ckpt_model = saver.save(self.sess, fname_ckpt_model,
                                              status.epoch)
                need_to_save = self.record_best_model(self.best_model_table,
                                                      eval_loss_vals,
                                                      fname_ckpt_model)
                if not need_to_save:
                    logger.warning('DELETE MODEL: rm {}.*'.format(fname_ckpt_model))
                    # Escape the prefix so glob metacharacters in the path match
                    # literally; delete best-effort like `rm` did.
                    for path in glob.glob(glob.escape(fname_ckpt_model) + '.*'):
                        try:
                            os.remove(path)
                        except OSError as err:
                            logger.warning('Could not delete {}: {}'.format(
                                path, err))
                self.pgrbar.Update(1)
                logger.info('ENTIRE-PROGRESS: {}\n'.format(
                    self.pgrbar.GetBar()))