def __init__(self,
                 config,
                 dataset):
        """Build the evaluation input pipeline, model graph and session.

        Args:
            config: run configuration; this method reads ``train_dir``,
                ``batch_size``, ``data_id``, ``model`` and ``checkpoint_path``.
            dataset: dataset to evaluate; checked against ``config.data_id``
                and wrapped in a non-shuffled, non-training input pipeline.

        Checkpoint selection: ``config.checkpoint_path`` if given, otherwise
        the latest checkpoint found in ``train_dir``. When neither exists the
        variables are randomly initialized here; otherwise the path is only
        recorded (presumably restored later by the caller — TODO confirm).
        Finally loads image/coordinate normalization statistics from disk.
        """
        self.config = config
        self.train_dir = config.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        self.dataset = dataset

        check_data_id(dataset, config.data_id)
        _, self.batch = create_input_ops(dataset, self.batch_size,
                                         data_id=config.data_id,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
        # Dummy op fetched alongside the metrics in each evaluation step.
        self.step_op = tf.no_op(name='step_no_op')

        # NOTE(review): the graph-level seed is set AFTER the model ops were
        # created above; in TF1 random ops read the graph seed at creation
        # time, so the model's initializers are probably not seeded by this.
        tf.set_random_seed(1234)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        # Created after the model so that it captures all model variables.
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint_path = config.checkpoint_path
        if self.checkpoint_path is None and self.train_dir:
            # Returns None when train_dir holds no checkpoint.
            self.checkpoint_path = tf.train.latest_checkpoint(self.train_dir)

        if self.checkpoint_path is None:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            log.info("Checkpoint path : %s", self.checkpoint_path)

        # NOTE(review): path is relative to the current working directory,
        # not to this source file.
        mean_std = np.load('../DatasetCreation/VG/mean_std.npz')
        self.img_mean = mean_std['img_mean']
        self.img_std = mean_std['img_std']
        self.coords_mean = mean_std['coords_mean']
        self.coords_std = mean_std['coords_std']
# --- Example #2 ---
# 0
class Evaler(object):
    """Evaluate a trained model for one pass over a dataset.

    The constructor builds the input pipeline, the model graph and a session.
    ``eval_run`` restores the checkpoint (if any), feeds every batch through
    the model and hands ids/predictions/targets to ``EvalManager``.
    """

    @staticmethod
    def get_model_class(model_name):
        """Return the ``Model`` class for *model_name*.

        Raises:
            ValueError: if *model_name* is neither 'baseline' nor 'rn'.
        """
        if model_name == 'baseline':
            from model_baseline import Model
        elif model_name == 'rn':
            from model_rn import Model
        else:
            raise ValueError(model_name)
        return Model

    @staticmethod
    def _num_eval_steps(length_dataset, batch_size):
        """Return the number of batches needed to cover *length_dataset* examples.

        Uses ceiling division, so a dataset whose length is an exact multiple
        of *batch_size* no longer gets an extra (duplicated) batch. At least
        one step is always run, as before.
        """
        return max(1, -(-length_dataset // batch_size))

    def __init__(self, config, dataset):
        """Build the evaluation input pipeline, model graph and session.

        Args:
            config: run configuration; reads ``train_dir``, ``batch_size``,
                ``data_id``, ``model``, ``checkpoint_path`` (and ``name``
                later in ``eval_run``).
            dataset: dataset to evaluate; checked against ``config.data_id``.
        """
        self.config = config
        self.train_dir = config.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        self.dataset = dataset

        check_data_id(dataset, config.data_id)
        _, self.batch = create_input_ops(dataset,
                                         self.batch_size,
                                         data_id=config.data_id,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        # NOTE(review): set after the model ops exist, so it likely does not
        # seed the model's initializers (TF1 reads the graph seed at op
        # creation time).
        tf.set_random_seed(1234)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint_path = config.checkpoint_path
        if self.checkpoint_path is None and self.train_dir:
            self.checkpoint_path = tf.train.latest_checkpoint(self.train_dir)
        if self.checkpoint_path is None:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            log.info("Checkpoint path : %s", self.checkpoint_path)

    def eval_run(self):
        """Run inference over the whole dataset once and report the results."""
        # load checkpoint
        if self.checkpoint_path:
            self.saver.restore(self.session, self.checkpoint_path)
            log.info("Loaded from checkpoint!")

        log.infov("Start 1-epoch Inference and Evaluation")

        length_dataset = len(self.dataset)
        log.info("# of examples = %d", length_dataset)

        max_steps = self._num_eval_steps(length_dataset, self.batch_size)
        log.info("max_steps = %d", max_steps)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord,
                                               start=True)

        evaler = EvalManager()
        try:
            for s in xrange(max_steps):
                step, accuracy, step_time, batch_chunk, prediction_pred, prediction_gt = \
                    self.run_single_step(self.batch)
                self.log_step_message(s, accuracy, step_time)
                evaler.add_batch(batch_chunk['id'], prediction_pred,
                                 prediction_gt)

        except Exception as e:
            # Record the failure; Coordinator.join() below re-raises it (unless
            # it is a clean-stop type such as OutOfRangeError) once the
            # queue-runner threads have stopped.
            coord.request_stop(e)

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))

        evaler.report(name=self.config.name)
        log.infov("Evaluation complete.")

    def run_single_step(self, batch, step=None, is_train=True):
        """Evaluate one batch.

        Returns:
            ``(global_step, accuracy, seconds_elapsed, batch_chunk,
            predictions, targets)``. ``step`` and ``is_train`` are unused.
        """
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

        fetches = [
            self.global_step, self.model.accuracy, self.model.all_preds,
            self.model.a, self.step_op
        ]
        [step, accuracy, all_preds, all_targets, _] = self.session.run(
            fetches, feed_dict=self.model.get_feed_dict(batch_chunk))

        _end_time = time.time()

        return step, accuracy, (
            _end_time - _start_time), batch_chunk, all_preds, all_targets

    def log_step_message(self, step, accuracy, step_time, is_train=False):
        """Log one evaluation step's accuracy (as a percentage) and throughput."""
        if step_time == 0:
            step_time = 0.001  # avoid division by zero below
        log_fn = log.info if is_train else log.infov
        log_fn((
            " [{split_mode:5s} step {step:4d}] " +
            "batch total-accuracy (test): {test_accuracy:.2f}% " +
            "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
        ).format(
            split_mode='train' if is_train else 'val',
            step=step,
            test_accuracy=accuracy * 100,
            sec_per_batch=step_time,
            instance_per_sec=self.batch_size / step_time,
        ))
# --- Example #3 ---
# 0
class Trainer(object):
    """Train a model with Adam and periodically checkpoint it.

    The constructor builds train/test input pipelines, the model, an
    ``optimize_loss`` training op and a ``tf.train.Supervisor``-managed
    session; ``train`` runs the main loop.
    """

    @staticmethod
    def get_model_class(model_name):
        """Return the ``Model`` class for *model_name*.

        Raises:
            ValueError: if *model_name* is neither 'baseline' nor 'rn'.
        """
        if model_name == 'baseline':
            from model_baseline import Model
        elif model_name == 'rn':
            from model_rn import Model
        else:
            raise ValueError(model_name)
        return Model

    def __init__(self, config, dataset, dataset_test):
        """Build inputs, model, optimizer, summaries, saver and session.

        Args:
            config: run configuration; reads ``dataset_path``,
                ``learning_rate``, ``model``, ``prefix``, ``batch_size``,
                ``lr_weight_decay`` and ``checkpoint``.
            dataset: training dataset.
            dataset_test: dataset used for the per-step test accuracy.
        """
        self.config = config
        hyper_parameter_str = config.dataset_path + '_lr_' + str(
            config.learning_rate)
        self.train_dir = './train_dir/%s-%s-%s-%s' % (
            config.model, config.prefix, hyper_parameter_str,
            time.strftime("%Y%m%d-%H%M%S"))

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        _, self.batch_train = create_input_ops(dataset,
                                               self.batch_size,
                                               is_training=True)
        _, self.batch_test = create_input_ops(dataset_test,
                                              self.batch_size,
                                              is_training=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.learning_rate = config.learning_rate
        if config.lr_weight_decay:
            # Halve the learning rate every 10k global steps (staircase).
            self.learning_rate = tf.train.exponential_decay(
                self.learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True,
                name='decaying_learning_rate')

        self.check_op = tf.no_op()

        self.optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer,
            clip_gradients=20.0,
            name='optimizer_loss')

        self.summary_op = tf.summary.merge_all()
        # Plot summaries are optional (they need the ``tfplot`` package).
        # ``None`` means "nothing extra to fetch".
        self.plot_summary_op = None
        try:
            import tfplot  # noqa: F401 -- only probing availability
            self.plot_summary_op = tf.summary.merge_all(key='plot_summaries')
        except Exception:
            log.info("tfplot unavailable; plot summaries disabled")

        self.saver = tf.train.Saver(max_to_keep=1)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)

        self.checkpoint_secs = 600  # 10 min

        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            # intra_op_parallelism_threads=1,
            # inter_op_parallelism_threads=1,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(
            config=session_config)

        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.saver.restore(self.session, self.ckpt_path)
            log.info(
                "Loaded the pretrain parameters from the provided checkpoint path"
            )

    def train(self):
        """Run the training loop, evaluating one test batch per step.

        Logs every 10 steps and saves a checkpoint whenever the global step
        is a multiple of ``output_save_step``. Exits the process once the
        global step reaches ``max_steps``.
        """
        log.infov("Training Starts!")
        pprint(self.batch_train)

        max_steps = 50000

        output_save_step = 500

        for s in range(max_steps):
            step, accuracy, summary, loss, step_time = \
                self.run_single_step(self.batch_train, step=s, is_train=True)

            # periodic inference
            accuracy_test = self.run_test(self.batch_test, is_train=False)

            if s % 10 == 0:
                self.log_step_message(step, accuracy, accuracy_test, loss,
                                      step_time)

            self.summary_writer.add_summary(summary, global_step=step)

            # Read the global step once; saving does not modify it.
            current_step = self.session.run(self.global_step)
            if current_step % output_save_step == 0:
                try:
                    self.saver.save(self.session,
                                    os.path.join(self.train_dir, 'model'),
                                    global_step=step)
                    log.infov("Saved checkpoint at %d", s)
                except Exception:
                    # Best effort: a failed save must not kill a long run,
                    # but Ctrl-C / SystemExit are no longer swallowed.
                    log.warning("Error while saving checkpoint. Continuing!")

                if current_step == max_steps:
                    import sys
                    sys.exit()

    def run_single_step(self, batch, step=None, is_train=True):
        """Run one optimizer step on a freshly dequeued batch.

        Every 100 steps the optional plot summaries are fetched too and
        merged into the returned summary.

        Returns:
            ``(global_step, accuracy, summary, loss, seconds_elapsed)``.
        """
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

        fetch = [
            self.global_step, self.model.accuracy, self.summary_op,
            self.model.loss, self.check_op, self.optimizer
        ]

        # getattr: instances may predate/skip the attribute; a None op (no
        # plot summaries registered) must not be passed to session.run.
        plot_summary_op = getattr(self, 'plot_summary_op', None)
        fetch_plot = (plot_summary_op is not None and step is not None
                      and step % 100 == 0)
        if fetch_plot:
            fetch.append(plot_summary_op)

        fetch_values = self.session.run(fetch,
                                        feed_dict=self.model.get_feed_dict(
                                            batch_chunk, step=step))
        [step, accuracy, summary, loss] = fetch_values[:4]

        if fetch_plot:
            # Concatenating serialized protobuf messages merges them.
            summary += fetch_values[-1]

        _end_time = time.time()

        return step, accuracy, summary, loss, (_end_time - _start_time)

    def run_test(self, batch, is_train=False, repeat_times=8):
        """Return the model accuracy on one test batch.

        ``is_train`` and ``repeat_times`` are accepted for compatibility but
        unused.
        """
        batch_chunk = self.session.run(batch)

        accuracy_test = self.session.run(self.model.accuracy,
                                         feed_dict=self.model.get_feed_dict(
                                             batch_chunk, is_training=False))

        return accuracy_test

    def log_step_message(self,
                         step,
                         accuracy,
                         accuracy_test,
                         loss,
                         step_time,
                         is_train=True):
        """Log loss, train/test accuracy (percent) and throughput for a step."""
        if step_time == 0:
            step_time = 0.001  # avoid division by zero below
        log_fn = log.info if is_train else log.infov
        log_fn((
            " [{split_mode:5s} step {step:4d}] " + "Loss: {loss:.5f} " +
            "Accuracy train: {accuracy:.2f} " +
            "Accuracy test: {accuracy_test:.2f} " +
            "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
        ).format(split_mode='train' if is_train else 'val',
                 step=step,
                 loss=loss,
                 accuracy=accuracy * 100,
                 accuracy_test=accuracy_test * 100,
                 sec_per_batch=step_time,
                 instance_per_sec=self.batch_size / step_time))
# --- Example #4 ---
# 0
    def __init__(self, config, dataset, dataset_test):
        """Build inputs, model, optimizer, summaries, saver and session.

        Args:
            config: run configuration; this method reads ``dataset_path``,
                ``learning_rate``, ``model``, ``prefix``, ``batch_size``,
                ``lr_weight_decay`` and ``checkpoint``.
            dataset: training dataset (wrapped with ``is_training=True``).
            dataset_test: evaluation dataset (``is_training=False``).

        Creates a timestamped ``./train_dir/...`` directory, and restores
        ``config.checkpoint`` into the session when it is not None.
        """
        self.config = config
        hyper_parameter_str = config.dataset_path + '_lr_' + str(
            config.learning_rate)
        # Timestamp suffix gives every run its own directory.
        self.train_dir = './train_dir/%s-%s-%s-%s' % (
            config.model, config.prefix, hyper_parameter_str,
            time.strftime("%Y%m%d-%H%M%S"))

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        _, self.batch_train = create_input_ops(dataset,
                                               self.batch_size,
                                               is_training=True)
        _, self.batch_test = create_input_ops(dataset_test,
                                              self.batch_size,
                                              is_training=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.learning_rate = config.learning_rate
        if config.lr_weight_decay:
            # Halve the learning rate every 10k global steps (staircase).
            self.learning_rate = tf.train.exponential_decay(
                self.learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True,
                name='decaying_learning_rate')

        self.check_op = tf.no_op()

        # Adam with gradient clipping at 20.0; also increments global_step.
        self.optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer,
            clip_gradients=20.0,
            name='optimizer_loss')

        self.summary_op = tf.summary.merge_all()
        # Plot summaries are optional: if importing tfplot fails, the
        # attribute is simply never set.
        # NOTE(review): the bare except also swallows KeyboardInterrupt.
        try:
            import tfplot
            self.plot_summary_op = tf.summary.merge_all(key='plot_summaries')
        except:
            pass

        self.saver = tf.train.Saver(max_to_keep=1)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)

        self.checkpoint_secs = 600  # 10 min

        # saver=None / summary_op=None: checkpoints and summaries are written
        # explicitly by this class rather than by the Supervisor's own ops.
        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            # intra_op_parallelism_threads=1,
            # inter_op_parallelism_threads=1,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(
            config=session_config)

        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.saver.restore(self.session, self.ckpt_path)
            log.info(
                "Loaded the pretrain parameters from the provided checkpoint path"
            )
                    type=str,
                    help='pretrained model checkpoint')
def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--train False`` used to
    enable training. Raising ValueError makes argparse report the value as
    invalid.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(value)


parser.add_argument('--epochs', default=101, type=int, help='train epochs')
parser.add_argument('--train', default=True, type=_str2bool, help='train')
args = parser.parse_args()

# Plain concatenation on purpose: args.save_path acts as a prefix.
save_path = args.save_path + f'{args.message}_{time_str}'

# makedirs also creates a missing parent directory (os.mkdir would fail).
os.makedirs(save_path, exist_ok=True)
logger = Logger(f'{save_path}/log.log')
logger.Print(args)

train_data, val_data, test_data = load_cisia_surf(train_size=args.batch_size,
                                                  test_size=args.test_size)
model = Model(pretrained=False, num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),
                      lr=0.01,
                      momentum=0.9,
                      weight_decay=5e-4)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

if use_cuda:
    model = model.cuda()
    criterion = criterion.cuda()

train_loss = []
eval_loss = []
eval_history = []
# --- Example #6 ---
# 0
class Trainer(object):
    """Train a baseline / RN / ILP model with Adam and checkpoint regularly.

    Also keeps 20-entry moving averages of the logged loss, train accuracy
    and test accuracy (updated only when a step is logged).
    """

    @staticmethod
    def get_model_class(model_name):
        """Return the ``Model`` class for *model_name*.

        Raises:
            ValueError: if *model_name* is not 'baseline', 'rn' or 'ilp'.
        """
        if model_name == 'baseline':
            from model_baseline import Model
        elif model_name == 'rn':
            from model_rn import Model
        elif model_name == 'ilp':
            from model_ilp import Model
        else:
            raise ValueError(model_name)
        return Model

    def __init__(self, config, dataset, dataset_test):
        """Build inputs, model, optimizer, summaries, saver and session.

        Args:
            config: run configuration; reads ``dataset_path``,
                ``learning_rate``, ``learning_rate_ilp``, ``model``,
                ``prefix``, ``batch_size``, ``lr_weight_decay``, ``checkpoint``.
            dataset: training dataset.
            dataset_test: dataset used for the per-step test accuracy.
        """
        self.config = config
        hyper_parameter_str = config.dataset_path + '_lr_' + str(
            config.learning_rate)
        self.train_dir = './train_dir/%s-%s-%s-%s' % (
            config.model, config.prefix, hyper_parameter_str,
            time.strftime("%Y%m%d-%H%M%S"))

        # Moving averages over the last 20 *logged* values.
        self.avgLoss = MovingFn(np.mean, 20)
        self.avgAccuracy = MovingFn(np.mean, 20)
        self.avgAccuracyTest = MovingFn(np.mean, 20)
        self.cnt = 0  # number of training steps run by this instance
        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        _, self.batch_train = create_input_ops(dataset,
                                               self.batch_size,
                                               is_training=True)
        _, self.batch_test = create_input_ops(dataset_test,
                                              self.batch_size,
                                              is_training=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.learning_rate = config.learning_rate
        self.learning_rate_ilp = config.learning_rate_ilp
        if config.lr_weight_decay:
            # Multiply the learning rate by 0.9 every 10k global steps.
            self.learning_rate = tf.train.exponential_decay(
                self.learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.9,
                staircase=True,
                name='decaying_learning_rate')
            # NOTE: learning_rate_ilp is currently not decayed.

        self.check_op = tf.no_op()

        self.optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer,
            clip_gradients=10.,
            name='optimizer_loss')

        self.summary_op = tf.summary.merge_all()
        # Plot summaries are optional (they need the ``tfplot`` package).
        # ``None`` means "nothing extra to fetch".
        self.plot_summary_op = None
        try:
            import tfplot  # noqa: F401 -- only probing availability
            self.plot_summary_op = tf.summary.merge_all(key='plot_summaries')
        except Exception:
            log.info("tfplot unavailable; plot summaries disabled")

        self.saver = tf.train.Saver(max_to_keep=1000)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)

        self.checkpoint_secs = 600  # 10 min

        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            # intra_op_parallelism_threads=1,
            # inter_op_parallelism_threads=1,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(
            config=session_config)

        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.saver.restore(self.session, self.ckpt_path)
            log.info(
                "Loaded the pretrain parameters from the provided checkpoint path"
            )

    def train(self):
        """Run the training loop, evaluating one test batch per step.

        Logs every 10 steps and saves a checkpoint every
        ``output_save_step`` loop iterations.
        """
        log.infov("Training Starts!")
        pprint(self.batch_train)

        max_steps = 2000000

        output_save_step = 1000

        for s in xrange(max_steps):
            step, accuracy, summary, loss, step_time, pen = \
                self.run_single_step(self.batch_train, step=s, is_train=True)

            # periodic inference
            accuracy_test = self.run_test(self.batch_test, is_train=False)

            if s % 10 == 0:
                self.log_step_message(pen, step, accuracy, accuracy_test, loss,
                                      step_time)

            self.summary_writer.add_summary(summary, global_step=step)

            if s % output_save_step == 0:
                self.saver.save(self.session,
                                os.path.join(self.train_dir, 'model'),
                                global_step=step)
                # Log only after the save actually succeeded.
                log.infov("Saved checkpoint at %d", s)

    def run_single_step(self, batch, step=None, is_train=True):
        """Run one optimizer step on a freshly dequeued batch.

        Every 100 steps the optional plot summaries are fetched too and
        merged into the returned summary.

        Returns:
            ``(global_step, accuracy, summary, loss, seconds_elapsed, pen)``
            where ``pen`` is the value of ``self.model.pen``.
        """
        _start_time = time.time()

        self.cnt += 1

        batch_chunk = self.session.run(batch)

        fetch = [
            self.model.pen, self.global_step, self.model.accuracy,
            self.summary_op, self.model.loss, self.check_op, self.optimizer
        ]

        # getattr: instances may lack the attribute; a None op (no plot
        # summaries registered) must not be passed to session.run.
        plot_summary_op = getattr(self, 'plot_summary_op', None)
        fetch_plot = (plot_summary_op is not None and step is not None
                      and step % 100 == 0)
        if fetch_plot:
            fetch.append(plot_summary_op)

        fetch_values = self.session.run(fetch,
                                        feed_dict=self.model.get_feed_dict(
                                            batch_chunk, step=step))

        [pen, step, accuracy, summary, loss] = fetch_values[:5]

        if fetch_plot:
            # Concatenating serialized protobuf messages merges them.
            summary += fetch_values[-1]

        _end_time = time.time()

        return step, accuracy, summary, loss, (_end_time - _start_time), pen

    def run_test(self, batch, is_train=False, repeat_times=8):
        """Return the model accuracy on one test batch.

        ``is_train`` and ``repeat_times`` are accepted for compatibility but
        unused.
        """
        batch_chunk = self.session.run(batch)

        accuracy_test = self.session.run(self.model.accuracy,
                                         feed_dict=self.model.get_feed_dict(
                                             batch_chunk, is_training=False))

        return accuracy_test

    def log_step_message(self,
                         pen,
                         step,
                         accuracy,
                         accuracy_test,
                         loss,
                         step_time,
                         is_train=True):
        """Log loss, train/test accuracy, throughput, moving averages and pen.

        Side effect: pushes loss and both accuracies into the moving-average
        trackers.
        """
        if step_time == 0:
            step_time = 0.001  # avoid division by zero below
        log_fn = log.info if is_train else log.infov
        log_fn((
            " [{split_mode:5s} step {step:4d}] " + "Loss: {loss:.5f} " +
            "Accuracy train: {accuracy:.2f} " +
            "Accuracy test: {accuracy_test:.2f} " +
            "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
            +
            "Avg:{avgloss:.2f}, {avgaccuracy:.2f}, {avgaccuracytest:.2f},{pens:.2f}"
        ).format(split_mode='train' if is_train else 'val',
                 step=step,
                 loss=loss,
                 accuracy=accuracy * 100,
                 accuracy_test=accuracy_test * 100,
                 sec_per_batch=step_time,
                 instance_per_sec=self.batch_size / step_time,
                 avgloss=self.avgLoss.add(loss),
                 avgaccuracy=self.avgAccuracy.add(accuracy * 100),
                 avgaccuracytest=self.avgAccuracyTest.add(accuracy_test * 100),
                 pens=pen))
class Evaler(object):
    """Evaluate answer accuracy (and optionally box IoU) over one dataset pass.

    Accuracy is split between relational and non-relational questions; when
    ``config.location`` is set, predicted boxes are de-normalized and scored
    against ground truth with IoU.
    """

    @staticmethod
    def get_model_class(model_name):
        """Return the ``Model`` class for *model_name*.

        Raises:
            ValueError: if *model_name* is neither 'baseline' nor 'rn'.
        """
        if model_name == 'baseline':
            from model_baseline import Model
        elif model_name == 'rn':
            from model_rn import Model
        else:
            raise ValueError(model_name)
        return Model

    @staticmethod
    def _num_eval_steps(length_dataset, batch_size):
        """Return the number of batches needed to cover *length_dataset* examples.

        Uses ceiling division, so a dataset whose length is an exact multiple
        of *batch_size* no longer gets an extra (duplicated) batch. At least
        one step is always run, as before.
        """
        return max(1, -(-length_dataset // batch_size))

    def __init__(self,
                 config,
                 dataset):
        """Build the evaluation input pipeline, model graph and session.

        Also loads image/coordinate normalization statistics from a path
        relative to the current working directory.
        """
        self.config = config
        self.train_dir = config.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        self.dataset = dataset

        check_data_id(dataset, config.data_id)
        _, self.batch = create_input_ops(dataset, self.batch_size,
                                         data_id=config.data_id,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        # NOTE(review): set after the model ops exist, so it likely does not
        # seed the model's initializers.
        tf.set_random_seed(1234)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint_path = config.checkpoint_path
        if self.checkpoint_path is None and self.train_dir:
            self.checkpoint_path = tf.train.latest_checkpoint(self.train_dir)

        if self.checkpoint_path is None:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            log.info("Checkpoint path : %s", self.checkpoint_path)

        mean_std = np.load('../DatasetCreation/VG/mean_std.npz')
        self.img_mean = mean_std['img_mean']
        self.img_std = mean_std['img_std']
        self.coords_mean = mean_std['coords_mean']
        self.coords_std = mean_std['coords_std']

    def IoU(self, boxA, boxB):
        """Return the row-wise intersection-over-union of two box arrays.

        Rows are ``(x, y, w, h)``; columns 2 and 3 are converted to the far
        corner before comparison. Areas use the inclusive-pixel ``+ 1``
        convention. Values outside ``[0, 1]`` are reported as 0. The inputs
        are not modified (``astype`` makes copies).
        """
        boxA = boxA.astype(np.float64)
        boxB = boxB.astype(np.float64)

        # (x, y, w, h) -> (x1, y1, x2, y2)
        boxA[:, 2] = boxA[:, 0] + boxA[:, 2]
        boxA[:, 3] = boxA[:, 1] + boxA[:, 3]
        boxB[:, 2] = boxB[:, 0] + boxB[:, 2]
        boxB[:, 3] = boxB[:, 1] + boxB[:, 3]
        # determine the (x, y)-coordinates of the intersection rectangle
        xA = np.maximum(boxA[:, 0], boxB[:, 0])
        yA = np.maximum(boxA[:, 1], boxB[:, 1])
        xB = np.minimum(boxA[:, 2], boxB[:, 2])
        yB = np.minimum(boxA[:, 3], boxB[:, 3])
        # Clamp each side at 0: without this, boxes disjoint on BOTH axes
        # yield two negative factors whose product is a spurious positive
        # intersection area.
        inter_w = np.maximum(0.0, xB - xA + 1)
        inter_h = np.maximum(0.0, yB - yA + 1)
        interArea = inter_w * inter_h
        # areas of both the prediction and ground-truth rectangles
        boxAArea = (boxA[:, 2] - boxA[:, 0] + 1) * (boxA[:, 3] - boxA[:, 1] + 1)
        boxBArea = (boxB[:, 2] - boxB[:, 0] + 1) * (boxB[:, 3] - boxB[:, 1] + 1)
        # intersection / union
        iou = interArea / (boxAArea + boxBArea - interArea)
        # Degenerate boxes can still push the ratio out of range.
        iou[iou > 1] = 0
        iou[iou < 0] = 0
        return iou

    def eval_run(self):
        """Run inference over the dataset once and report split accuracies/IoU."""
        # load checkpoint
        if self.checkpoint_path:
            self.saver.restore(self.session, self.checkpoint_path)
            log.info("Loaded from checkpoint!")

        log.infov("Start 1-epoch Inference and Evaluation")

        length_dataset = len(self.dataset)
        log.info("# of examples = %d", length_dataset)

        max_steps = self._num_eval_steps(length_dataset, self.batch_size)
        log.info("max_steps = %d", max_steps)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord, start=True)

        evaler = EvalManager()
        try:
            for s in xrange(max_steps):
                step, acc, step_time, batch_chunk, prediction_pred, prediction_gt, p_l = \
                    self.run_single_step(self.batch)

                question_array = batch_chunk['q']
                answer_array = batch_chunk['a']

                # De-normalize the first image of the batch (in place) for
                # visualization.
                img = batch_chunk['img'][0]
                img *= self.img_std
                img += self.img_mean
                img = img.astype(np.uint8)

                # Question type = argmax over columns 30+; types 0-1 are
                # counted as non-relational, the rest as relational.
                question_type = np.argmax(question_array[:, 30:], axis=1)
                nonrelational_indx = question_type < 2
                relational_indx = question_type > 1

                relational_pred_ans = prediction_pred[relational_indx]
                relational_ans = answer_array[relational_indx]

                nonrelational_pred_ans = prediction_pred[nonrelational_indx]
                nonrelational_ans = answer_array[nonrelational_indx]

                nonrelational_correct = np.sum(
                    np.argmax(nonrelational_pred_ans, axis=1) ==
                    np.argmax(nonrelational_ans, axis=1))
                relational_correct = np.sum(
                    np.argmax(relational_pred_ans, axis=1) ==
                    np.argmax(relational_ans, axis=1))

                if self.config.location:
                    # De-normalize predicted and ground-truth boxes in place.
                    p_l *= self.coords_std
                    p_l += self.coords_mean

                    location = batch_chunk['l']
                    location *= self.coords_std
                    location += self.coords_mean

                    iou = self.IoU(p_l, location)
                    r_iou = iou[relational_indx].tolist()
                    nr_iou = iou[nonrelational_indx].tolist()
                    log.info("IoU: %f", np.mean(iou))
                else:
                    r_iou, nr_iou = 0, 0

                evaler.add_batch([relational_correct, len(relational_ans)],
                                 [nonrelational_correct, len(nonrelational_ans)],
                                 r_iou, nr_iou)

                if self.config.visualize:
                    # NOTE(review): `location` and a numeric `p_l` only exist
                    # when config.location is set; visualize without it fails.
                    q = np.argmax(question_array[0][30:])
                    a = np.argmax(answer_array[0])
                    p_a = np.argmax(prediction_pred[0])
                    obj = np.argmax(question_array[0][:15])

                    visualize_prediction(img, q, ans_look_up[a], ans_look_up[p_a],
                                         location[0], p_l[0], obj_look_up[obj], id=s)

                self.log_step_message(s, acc, step_time)

        except Exception as e:
            # Record the failure; Coordinator.join() re-raises it (unless it
            # is a clean-stop type such as OutOfRangeError).
            coord.request_stop(e)

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))

        evaler.report()
        log.infov("Evaluation complete.")

    def run_single_step(self, batch, step=None, is_train=True):
        """Evaluate one batch.

        Returns:
            ``(global_step, accuracy, seconds_elapsed, batch_chunk,
            predictions, targets, predicted_location)`` where the last item
            is the string ``'N/A'`` when ``config.location`` is off.
        """
        _start_time = time.time()

        batch_chunk = self.session.run(batch)
        if self.config.location:
            [step, accuracy, all_preds, rpred, all_targets, _] = self.session.run(
                [self.global_step, self.model.accuracy, self.model.all_preds, self.model.rpred, self.model.a, self.step_op],
                feed_dict=self.model.get_feed_dict(batch_chunk)
            )
            _end_time = time.time()
            return step, accuracy, (_end_time - _start_time), batch_chunk, all_preds, all_targets, rpred
        else:
            [step, accuracy, all_preds, all_targets, _] = self.session.run(
                [self.global_step, self.model.accuracy, self.model.all_preds, self.model.a, self.step_op],
                feed_dict=self.model.get_feed_dict(batch_chunk)
            )
            _end_time = time.time()
            return step, accuracy, (_end_time - _start_time), batch_chunk, all_preds, all_targets, 'N/A'

    def log_step_message(self, step, accuracy, step_time, is_train=False):
        """Log one evaluation step's accuracy (as a percentage) and throughput."""
        if step_time == 0:
            step_time = 0.001  # avoid division by zero below
        log_fn = log.info if is_train else log.infov
        log_fn((" [{split_mode:5s} step {step:4d}] " +
                "batch total-accuracy (test): {test_accuracy:.2f}% " +
                "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
                ).format(split_mode='train' if is_train else 'val',
                         step=step,
                         test_accuracy=accuracy * 100,
                         sec_per_batch=step_time,
                         instance_per_sec=self.batch_size / step_time,
                         )
               )
class Trainer(object):
    @staticmethod
    def get_model_class(model_name):
        if model_name == 'baseline':
            from model_baseline import Model
        elif model_name == 'relational_network':
            from model_rn import Model
        elif model_name == 'attentional_relational_network':
            from model_attentional_rn import Model
        else:
            raise ValueError(model_name)
        return Model

    def __init__(self, config, dataset_train, dataset_val, dataset_test):
        """Build input pipelines, the model, the optimizer and a session.

        Args:
            config: run configuration. Reads ``dataset_path``,
                ``learning_rate``, ``model``, ``prefix``, ``batch_size``,
                ``lr_weight_decay`` and ``checkpoint``.
            dataset_train: training dataset passed to ``create_input_ops``
                (shuffled / training mode); its ``len()`` sizes an epoch.
            dataset_val: validation dataset (``is_training=False``).
            dataset_test: test dataset (``is_training=False``).
        """
        self.config = config
        hyper_parameter_str = config.dataset_path + '_lr_' + str(
            config.learning_rate)
        # A timestamp suffix gives every run its own directory.
        self.train_dir = './train_dir/%s-%s-%s-%s' % (
            config.model, config.prefix, hyper_parameter_str,
            time.strftime("%Y%m%d-%H%M%S"))

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)
        log.infov("Train Dir: %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        _, self.batch_train = create_input_ops(dataset_train,
                                               self.batch_size,
                                               is_training=True)

        _, self.batch_val = create_input_ops(dataset_val,
                                             self.batch_size,
                                             is_training=False)

        _, self.batch_test = create_input_ops(dataset_test,
                                              self.batch_size,
                                              is_training=False)
        self.train_length = len(dataset_train)
        self.val_length = len(dataset_val)
        self.test_length = len(dataset_test)

        # --- create model ---
        Model = self.get_model_class(config.model)
        log.infov("Using Model class : %s", Model)
        self.model = Model(config)

        # --- optimizer ---
        self.global_step = tf.contrib.framework.get_or_create_global_step(
            graph=None)
        self.learning_rate = config.learning_rate
        if config.lr_weight_decay:
            # Halve the learning rate every 10000 steps (staircase schedule).
            self.learning_rate = tf.train.exponential_decay(
                self.learning_rate,
                global_step=self.global_step,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True,
                name='decaying_learning_rate')

        self.check_op = tf.no_op()

        self.optimizer = tf.contrib.layers.optimize_loss(
            loss=self.model.loss,
            global_step=self.global_step,
            learning_rate=self.learning_rate,
            optimizer=tf.train.AdamOptimizer,
            clip_gradients=20.0,
            name='optimizer_loss')

        self.summary_op = tf.summary.merge_all()
        try:
            import tfplot  # noqa: F401 -- only probing that tfplot is importable
            self.plot_summary_op = tf.summary.merge_all(key='plot_summaries')
        except Exception as e:
            # Plot summaries are optional. ``plot_summary_op`` is left unset and
            # run_single_step skips it. Catch ``Exception`` rather than using a
            # bare ``except:`` so KeyboardInterrupt/SystemExit still propagate.
            log.warn("Plot summaries disabled: %s", e)

        self.saver = tf.train.Saver(max_to_keep=5)
        self.best_val_saver = tf.train.Saver(max_to_keep=5)
        self.summary_writer = tf.summary.FileWriter(self.train_dir)

        self.checkpoint_secs = 600  # 10 min

        # NOTE(review): with ``saver=None`` the Supervisor should not write
        # periodic checkpoints itself, so ``save_model_secs`` likely has no
        # effect. Checkpoints are written explicitly in ``train()``. Confirm
        # against the TF1 Supervisor docs.
        self.supervisor = tf.train.Supervisor(
            logdir=self.train_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            save_summaries_secs=300,
            save_model_secs=self.checkpoint_secs,
            global_step=self.global_step,
        )

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = self.supervisor.prepare_or_wait_for_session(
            config=session_config)

        # Optionally warm-start from an explicit checkpoint path.
        self.ckpt_path = config.checkpoint
        if self.ckpt_path is not None:
            log.info("Checkpoint path: %s", self.ckpt_path)
            self.saver.restore(self.session, self.ckpt_path)
            log.info(
                "Loaded the pretrain parameters from the provided checkpoint path"
            )

    def train(self):
        log.infov("Training Starts!")
        pprint(self.batch_train)

        step = 0
        output_save_step = 1000
        epoch_train_iter = int(self.train_length / self.batch_size)  # * 10
        epoch_val_iter = int(self.val_length / self.batch_size)  # * 10
        total_epochs = int(200000 / epoch_train_iter)

        best_val_accuracy = 0.
        with tqdm.tqdm(total=total_epochs) as epoch_bar:

            for e in range(total_epochs):
                train_loss = []
                train_accuracy = []
                val_loss = []
                val_accuracy = []
                total_train_time = []
                with tqdm.tqdm(total=epoch_train_iter) as train_bar:
                    for train_step in range(epoch_train_iter):
                        step, accuracy, summary, loss, step_time = \
                            self.run_single_step(self.batch_train, step=step, is_train=True)
                        step += 1
                        train_loss.append(loss)
                        train_accuracy.append(accuracy)
                        total_train_time.append(step_time)
                        train_bar.update(1)
                        train_bar.set_description(
                            "Train loss: {train_loss}, Train accuracy: {train_accuracy},"
                            "Train loss mean: {train_loss_mean}, "
                            "Train accuracy mean: {train_accuracy_mean}".
                            format(
                                train_loss=loss,
                                train_accuracy=accuracy,
                                train_loss_mean=np.mean(train_loss),
                                train_accuracy_mean=np.mean(train_accuracy)))

                    train_loss_mean = np.mean(train_loss)
                    train_loss_std = np.std(train_loss)
                    train_accuracy_mean = np.mean(train_accuracy)
                    train_accuracy_std = np.std(train_accuracy)
                    total_train_time = np.sum(total_train_time)

                with tqdm.tqdm(total=epoch_val_iter) as val_bar:
                    for val_iters in range(epoch_val_iter):

                        loss, accuracy = \
                            self.run_test(self.batch_val, is_train=False)
                        val_loss.append(loss)
                        val_accuracy.append(accuracy)

                        val_bar.update(1)
                        val_bar.set_description(
                            "Val loss: {val_loss}, Val accuracy: {val_accuracy},"
                            "Val loss mean: {val_loss_mean}, Val accuracy mean: {val_accuracy_mean}"
                            .format(val_loss=loss,
                                    val_accuracy=loss,
                                    val_loss_mean=np.mean(val_loss),
                                    val_accuracy_mean=np.mean(train_accuracy)))

                    val_loss_mean = np.mean(val_loss)
                    val_loss_std = np.std(val_loss)
                    val_accuracy_mean = np.mean(val_accuracy)
                    val_accuracy_std = np.std(val_accuracy)

                if val_accuracy_mean >= best_val_accuracy:
                    best_val_accuracy = val_accuracy_mean
                    val_save_path = self.best_val_saver.save(
                        self.session,
                        os.path.join(self.train_dir, 'model'),
                        global_step=step)
                    print("Saved best val model at", val_save_path)

                self.log_step_message(step,
                                      train_accuracy_mean,
                                      val_loss_mean,
                                      val_accuracy_mean,
                                      train_loss_mean,
                                      total_train_time,
                                      is_train=True)

                self.summary_writer.add_summary(summary, global_step=step)

                log.infov("Saved checkpoint at %d", step)
                save_path = self.saver.save(self.session,
                                            os.path.join(
                                                self.train_dir, 'model'),
                                            global_step=step)
                print("Saved current train model at", save_path)
                epoch_bar.update(1)

    def run_single_step(self, batch, step=None, is_train=True):
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

        fetch = [
            self.global_step, self.model.accuracy, self.summary_op,
            self.model.loss, self.check_op, self.optimizer
        ]

        try:
            if step is not None and (step % 100 == 0):
                fetch += [self.plot_summary_op]
        except:
            pass

        fetch_values = self.session.run(fetch,
                                        feed_dict=self.model.get_feed_dict(
                                            batch_chunk,
                                            step=step,
                                            is_training=True))
        [step, accuracy, summary, loss] = fetch_values[:4]

        try:
            if self.plot_summary_op in fetch:
                summary += fetch_values[-1]
        except:
            pass

        _end_time = time.time()

        return step, accuracy, summary, loss, (_end_time - _start_time)

    def run_test(self, batch, is_train=False, repeat_times=8):

        batch_chunk = self.session.run(batch)

        loss, accuracy = self.session.run(
            [self.model.loss, self.model.accuracy],
            feed_dict=self.model.get_feed_dict(batch_chunk, is_training=False))

        return loss, accuracy

    def log_step_message(self,
                         step,
                         train_accuracy,
                         val_loss,
                         val_accuracy,
                         train_loss,
                         step_time,
                         is_train=True):
        if step_time == 0:
            step_time = 0.001
        log_fn = (is_train and log.info or log.infov)
        log_fn((
            " [{split_mode:5s} step {step:4d}] " +
            "Train Loss: {train_loss:.5f} " +
            "Train Accuracy: {train_accuracy:.2f} "
            "Validation Accuracy: {val_accuracy:.2f} " +
            "Validation Loss: {val_loss:.2f} " +
            "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
        ).format(split_mode=(is_train and 'train' or 'val'),
                 step=step,
                 train_loss=train_loss,
                 train_accuracy=train_accuracy * 100,
                 val_accuracy=val_accuracy * 100,
                 val_loss=val_loss,
                 sec_per_batch=step_time,
                 instance_per_sec=8000 / step_time))