Example #1
def _get_data(hyper, args, raw_data_dir_):
    df_train = pd.read_pickle(os.path.join(raw_data_dir_, 'df_train.pkl'))
    df_test = pd.read_pickle(os.path.join(raw_data_dir_, 'df_test.pkl'))
    ret_props = dlc.Properties({
        'df_train': df_train,
        'df_test': df_test,
        'train_seq_fname': 'raw_seq_train.pkl',
        'train_sq_seq_fname': 'raw_seq_sq_train.pkl',
        'valid_seq_fname': 'raw_seq_train.pkl',
        'valid_sq_seq_fname': 'raw_seq_sq_train.pkl',
        'test_seq_fname': 'raw_seq_test.pkl',
        'test_sq_seq_fname': 'raw_seq_sq_test.pkl'
    })

    if os.path.exists(os.path.join(raw_data_dir_, 'df_valid.pkl')):
        df_valid = pd.read_pickle(os.path.join(raw_data_dir_, 'df_valid.pkl'))
        ret_props.valid_seq_fname = 'raw_seq_valid.pkl'
        ret_props.valid_sq_seq_fname = 'raw_seq_sq_valid.pkl'
        hyper.logger.info('Loaded dataframes from %s df_train.shape=%s, df_valid.shape=%s, df_test.shape=%s' % (
            raw_data_dir_, df_train.shape, df_valid.shape, df_test.shape))
    else:
        hyper.logger.warn("Didn't find df_valid.pkl. Will split df_train into df_train and df_valid.")
        df_train, df_valid = split_dataset(df_train,
                                           hyper.data_reader_B,
                                           hyper.logger,
                                           args,
                                           hyper.assert_whole_batch,
                                           validation_frac=args.valid_frac)
        hyper.logger.info('Split dataframes from %s df_train.shape=%s, df_valid.shape=%s, df_test.shape=%s' % (
            raw_data_dir_, df_train.shape, df_valid.shape, df_test.shape))
    ret_props.df_valid = df_valid

    return ret_props
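For reference, a minimal call sketch of the loader above; the directory path is a placeholder, and hyper and args are assumed to be fully configured objects (hyper.logger, hyper.data_reader_B, hyper.assert_whole_batch and args.valid_frac set), as _get_data() and split_dataset() expect:

# Sketch only: '/path/to/raw_data' is a placeholder; hyper and args are assumed
# to be configured as required by _get_data() and split_dataset().
data = _get_data(hyper, args, '/path/to/raw_data')
print(data.df_train.shape, data.df_valid.shape, data.df_test.shape)
print(data.train_seq_fname, data.valid_seq_fname, data.test_seq_fname)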
Example #2
    def __next__(self):
        nxt = ShuffleIterator.__next__(self)
        # df_batch = nxt.df_batch[['image', 'bin_len', 'seq_len', 'squashed_len']]
        # a_batch = [
        #     self._image_processor.get_array(row[0]) for row in df_batch.itertuples(index=False)
        # ]
        # a_batch = np.asarray(a_batch)
        df_batch = nxt.df_batch[['image', 'height', 'width', 'bin_len', 'seq_len', 'squashed_len']]
        im_batch = [
            self._image_processor.get_array(row[0], row[1], row[2], self._padded_im_dim)
            for row in df_batch.itertuples(index=False)
        ]
        im_batch = self._image_processor.whiten(np.asarray(im_batch))

        bin_len = df_batch.bin_len.iloc[0]
        y_ctc = np.asarray(self._ctc_seq_data[bin_len].loc[df_batch.index].values, dtype=self._hyper.int_type_np)
        ctc_len = np.asarray(df_batch.squashed_len.values, dtype=self._hyper.int_type_np)
        if self._hyper.squash_input_seq:
            y_s = y_ctc
            seq_len = ctc_len
        else:
            y_s = np.asarray(self._seq_data[bin_len].loc[df_batch.index].values, dtype=self._hyper.int_type_np)
            seq_len = np.asarray(df_batch.seq_len.values, dtype=self._hyper.int_type_np)

        return dlc.Properties({'im':im_batch,
                               'y_s':y_s, # (B,T)
                               'seq_len': seq_len, #(B,)
                               ## 'y_s':y_ctc,
                               ## 'seq_len': ctc_len,
                               'y_ctc':y_ctc, #(B,T)
                               'ctc_len': ctc_len, #(B,)
                               'image_name': df_batch.image.values, #(B,)
                               'epoch': nxt.epoch, # scalar
                               'step': nxt.step # scalar
                               })
Example #3
    def __next__(self):
        with self.lock:
            if (self.max_steps >= 0) and (self._step >= self.max_steps):
                raise StopIteration('Max steps executed (%d)'%self._step)

            if self._next_pos >= self._num_items:
                ## Reshuffle sample-to-batch assignment
                self._df = self._df.sample(frac=1)
                ## Reshuffle the bin-batch list
                np.random.shuffle(self._batch_list)
                ## Shuffle the bin composition too
                self._df = self._df.sample(frac=1)
                self._next_pos = 0
                self._hyper.logger.debug('%s finished epoch %d'%(self._name, self._epoch))
                self._epoch += 1
            curr_pos = self._next_pos
            self._next_pos += 1 # value for next iteration
            epoch = self._epoch
            self._step += 1
            curr_step = self._step

        batch = self._batch_list[curr_pos]
        self._hyper.logger.debug('%s epoch %d, step %d, bin-batch idx %s', self._name, epoch, curr_step, batch)
        df_bin = self._df[self._df.bin_len == batch[0]]
        assert df_bin.bin_len.iloc[batch[1]*self._batch_size] == batch[0]
        assert df_bin.bin_len.iloc[(batch[1]+1)*self._batch_size-1] == batch[0]
        return dlc.Properties({
                'df_batch': df_bin.iloc[batch[1]*self._batch_size : (batch[1]+1)*self._batch_size],
                'epoch': epoch,
                'step': curr_step,
                'batch_idx': batch
                })
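The Properties object returned above is what the subclass iterators in Examples #2, #6 and #7 unpack. A minimal consumption sketch, assuming `it` is an already constructed ShuffleIterator (its constructor arguments are not shown in these examples):

# Sketch only: 'it' is assumed to be a fully constructed ShuffleIterator.
while True:
    try:
        nxt = next(it)                 # position updates are guarded by the iterator's lock
    except StopIteration:
        break                          # raised once max_steps have been executed
    df_batch = nxt.df_batch            # one bin-batch; every row shares the same bin_len
    assert df_batch.bin_len.nunique() == 1
    print(nxt.epoch, nxt.step, nxt.batch_idx)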
Example #4
    def test_bad_props(self):
        props = {'model_name': 'im2latex', 'num_layers': None, 'unset': None}
        open = dlc.Properties(props)
        sealed = dlc.Properties(open).seal()
        props['num_layers'] = 10
        frozen = dlc.Properties(props).freeze()

        self.assertRaises(dlc.AccessDeniedError, setattr, sealed, "x",
                          "MyNeuralNetwork")
        self.assertRaises(dlc.AccessDeniedError, self.dictSet, sealed, "x",
                          "MyNeuralNetwork")
        self.assertRaises(dlc.AccessDeniedError, setattr, frozen, "name",
                          "MyNeuralNetwork")
        self.assertRaises(dlc.AccessDeniedError, self.dictSet, frozen, "name",
                          "MyNeuralNetwork")

        self.assertRaises(KeyError, getattr, sealed, "x")
        self.assertRaises(KeyError, self.dictGet, sealed, "x")
Example #5
    def test_good_props(self):
        props = {'model_name': 'im2latex', 'num_layers': None, 'unset': None}
        open = dlc.Properties(props)
        sealed = dlc.Properties(open).seal()
        props['num_layers'] = 10
        frozen = dlc.Properties(props).freeze()

        open.layer_type = 'MLP'  # create new property
        self.assertEqual(open.layer_type, 'MLP')
        self.assertEqual(open['layer_type'], 'MLP')
        open['layer_type'] = 'CNN'
        self.assertEqual(open.layer_type, 'CNN')
        self.assertEqual(open['layer_type'], 'CNN')

        self.assertEqual(frozen.model_name, 'im2latex')
        self.assertEqual(frozen.unset, None)
        self.assertEqual(frozen['unset'], None)
        self.assertEqual(frozen['num_layers'], 10)
        self.assertEqual(frozen.num_layers, 10)
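Taken together, the two tests above pin down the mutability contract of dlc.Properties: an open instance accepts new keys, a sealed instance rejects unknown keys on both write and read, and a frozen instance rejects writes while still serving reads of its existing values. A minimal sketch of that contract, mirroring only the behaviour the tests exercise:

# Sketch only: mirrors the behaviour exercised by test_bad_props / test_good_props above.
props = dlc.Properties({'model_name': 'im2latex', 'num_layers': None})
props.layer_type = 'MLP'                # open: new keys may be created
assert props['layer_type'] == 'MLP'     # attribute and item access are interchangeable

sealed = dlc.Properties(props).seal()
# sealed.x = 1                          # raises dlc.AccessDeniedError (unknown key)
# _ = sealed.x                          # raises KeyError (unknown key)

frozen = dlc.Properties(props).freeze()
assert frozen.model_name == 'im2latex'  # reads of existing keys still work
# frozen.name = 'x'                     # raises dlc.AccessDeniedError (frozen is read-only)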
Example #6
    def __next__(self):
        nxt = ShuffleIterator.__next__(self)
        df_batch = nxt.df_batch[['image', 'height', 'width']]
        im_batch = [
            self._image_processor.get_array(os.path.join(self._image_dir, row[0]), row[1], row[2], self._padded_im_dim)
            for row in df_batch.itertuples(index=False)
        ]
        im_batch = self._image_processor.whiten(np.asarray(im_batch))

        return dlc.Properties({'im':im_batch,
                               'image_name': df_batch.image.values,
                               'epoch': nxt.epoch,
                               'step': nxt.step
                               })
Example #7
    def __next__(self):
        nxt = ShuffleIterator.__next__(self)
        df_batch = nxt.df_batch[['image', 'height', 'width', 'bin_len', 'seq_len']]
        im_batch = [
            self._image_processor.get_array(os.path.join(self._image_dir, row[0]), row[1], row[2], self._padded_im_dim)
            for row in df_batch.itertuples(index=False)
        ]
        im_batch = self._image_processor.whiten(np.asarray(im_batch))

        bin_len = df_batch.bin_len.iloc[0]
        y_s = self._seq_data[bin_len].loc[df_batch.index].values
        return dlc.Properties({'im':im_batch,
                               'y_s':y_s,
                               'seq_len': df_batch.seq_len.values,
                               'image_name': df_batch.image.values,
                               'epoch': nxt.epoch,
                               'step': nxt.step
                               })
Example #8
def make_hyper(initVals={}, freeze=True):
    initVals = dlc.Properties(initVals)
    ## initVals.image_frame_width = 1
    globals = GlobalParams(initVals)
    initVals.update(globals)

    # assert (globals.rLambda == 0) or (globals.dropout is None), 'Both dropouts and weights_regularizer are non-None'

    CALSTM_1 = CALSTMParams(initVals).freeze()
    # CALSTM_2 = CALSTM_1.copy({'m':CALSTM_1.decoder_lstm.layers_units[-1]}).freeze()
    # CALSTM_2 = CALSTMParams(initVals.copy().updated({'m':CALSTM_1.decoder_lstm.layers_units[-1]})).freeze()

    if globals.build_image_context != 2:
        CONVNET = None
    else:
        convnet_common = {
            'weights_initializer': globals.weights_initializer,
            'biases_initializer': globals.biases_initializer,
            'weights_regularizer': globals.weights_regularizer,
            'biases_regularizer': globals.biases_regularizer,
            'activation_fn': tf.nn.tanh,  # vgg16 has tf.nn.relu
            'padding': 'SAME',
        }

        # CONVNET = ConvStackParams({
        #     'op_name': 'Convnet',
        #     'tb': globals.tb,
        #     'layers': (
        #         ConvLayerParams(convnet_common).updated({'output_channels': 256, 'kernel_shape':(3,3), 'stride':(1,1), 'padding':'VALID'}).freeze(),
        #         # ConvLayerParams(convnet_common).updated({'output_channels': 512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
        #         MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
        #
        #         ConvLayerParams(convnet_common).updated({'output_channels':128, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
        #         MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
        #
        #         ConvLayerParams(convnet_common).updated({'output_channels':128, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
        #         MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
        #
        #         ConvLayerParams(convnet_common).updated({'output_channels':128, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
        #         MaxpoolParams(convnet_common).updated({'kernel_shape': (2, 2), 'stride': (2, 2)}).freeze(),
        #
        #         ConvLayerParams(convnet_common).updated({'output_channels':128, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
        #         MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
        #     )
        # }).freeze()

        CONVNET = ConvStackParams({
            'op_name': 'Convnet',
            'tb': globals.tb,
            'layers': (
                ConvLayerParams(convnet_common).updated({
                    'output_channels': 64,
                    'kernel_shape': (3, 3),
                    'stride': (1, 1),
                    'padding': 'VALID'
                }).freeze(),
                # ConvLayerParams(convnet_common).updated({'output_channels': 64, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
                MaxpoolParams(convnet_common).updated({
                    'kernel_shape': (2, 2),
                    'stride': (2, 2)
                }).freeze(),
                ConvLayerParams(convnet_common).updated({
                    'output_channels': 128,
                    'kernel_shape': (3, 3),
                    'stride': (1, 1)
                }).freeze(),
                MaxpoolParams(convnet_common).updated({
                    'kernel_shape': (2, 2),
                    'stride': (2, 2)
                }).freeze(),
                ConvLayerParams(convnet_common).updated({
                    'output_channels': 256,
                    'kernel_shape': (3, 3),
                    'stride': (1, 1)
                }).freeze(),
                MaxpoolParams(convnet_common).updated({
                    'kernel_shape': (2, 2),
                    'stride': (2, 2)
                }).freeze(),
                ConvLayerParams(convnet_common).updated({
                    'output_channels': 512,
                    'kernel_shape': (3, 3),
                    'stride': (1, 1)
                }).freeze(),
                MaxpoolParams(convnet_common).updated({
                    'kernel_shape': (2, 2),
                    'stride': (2, 2)
                }).freeze(),
                ConvLayerParams(convnet_common).updated({
                    'output_channels': 512,
                    'kernel_shape': (3, 3),
                    'stride': (1, 1)
                }).freeze(),
                MaxpoolParams(convnet_common).updated({
                    'kernel_shape': (2, 2),
                    'stride': (2, 2)
                }).freeze(),
            )
        }).freeze()

    # else: ## VGG16 architecture
    #     convnet_common = {
    #         'weights_initializer': globals.weights_initializer,
    #         'biases_initializer': globals.biases_initializer,
    #         'weights_regularizer': globals.weights_regularizer,
    #         'biases_regularizer': globals.biases_regularizer,
    #         'activation_fn': tf.nn.relu,
    #         'padding': 'SAME',
    #     }
    #     ## Conv and Maxpool architecture lifted from VGG16
    #     CONVNET = ConvStackParams({
    #         'op_name': 'Convnet',
    #         'tb': globals.tb,
    #         'layers': (
    #             ConvLayerParams(convnet_common).updated({'output_channels':64, 'kernel_shape':(3,3), 'stride':(1,1), 'padding':'VALID'}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':64, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
    #
    #             ConvLayerParams(convnet_common).updated({'output_channels':128, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':128, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
    #
    #             ConvLayerParams(convnet_common).updated({'output_channels':256, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':256, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':256, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
    #
    #             ConvLayerParams(convnet_common).updated({'output_channels':512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
    #
    #             ConvLayerParams(convnet_common).updated({'output_channels':512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             ConvLayerParams(convnet_common).updated({'output_channels':512, 'kernel_shape':(3,3), 'stride':(1,1)}).freeze(),
    #             MaxpoolParams(convnet_common).updated({'kernel_shape':(2,2), 'stride':(2,2)}).freeze(),
    #         )
    #     }).freeze()

    HYPER = Im2LatexModelParams(initVals).updated({
        'CALSTM_STACK': (CALSTM_1, ),
        'CONVNET': CONVNET
    })

    if freeze:
        HYPER.freeze()

    return HYPER
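Example #9 below builds the full set of overrides that is normally fed into this function; a minimal sketch of the call itself, using a small subset of the override keys assembled there:

# Sketch only: the override keys mirror a subset of those assembled in main() (Example #9).
overrides = dlc.Properties({'build_image_context': 2, 'dropout': None})
hyper = make_hyper(overrides, freeze=False)   # leave unfrozen so callers can still adjust it
# ... further edits to hyper ...
hyper.freeze()                                # freeze before handing it to the training code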
Example #9
def main():
    logger = dtc.makeLogger(set_global=True)

    parser = argparse.ArgumentParser(description='train model')
    parser.add_argument(
        "--num-steps",
        "-n",
        dest="num_steps",
        type=int,
        help=
        "Number of training steps to run. Defaults to -1 if unspecified, i.e. run to completion",
        default=-1)
    parser.add_argument(
        "--num-epochs",
        "-e",
        dest="num_epochs",
        type=int,
        help="Number of training epochs to run. Defaults to 10 if unspecified.",
        default=10)
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        type=int,
        help=
        "Batchsize per gpu. If unspecified, defaults to the default value in hyper_params",
        default=None)
    parser.add_argument(
        "--seq2seq-beam-width",
        "-w",
        dest="seq2seq_beam_width",
        type=int,
        help="seq2seq Beamwidth. If unspecified, defaults to 10",
        default=10)
    parser.add_argument("--ctc-beam-width",
                        dest="ctc_beam_width",
                        type=int,
                        help="CTC Beamwidth. If unspecified, defaults to 10",
                        default=10)
    parser.add_argument(
        "--print-steps",
        "-s",
        dest="print_steps",
        type=int,
        help=
        "Number of training steps after which to log results. Defaults to 100 if unspecified",
        default=100)
    parser.add_argument("--keep-prob",
                        "-k",
                        dest="keep_prob",
                        type=float,
                        help="Dropout 'keep' probability. Defaults to 0.5",
                        default=1.0)
    parser.add_argument(
        "--adam_alpha",
        "-a",
        dest="alpha",
        type=float,
        help="Alpha (step / learning-rate) value of adam optimizer.",
        default=0.0001)
    parser.add_argument(
        "--r-lambda",
        "-r",
        dest="rLambda",
        type=float,
        help=
        "Sets value of rLambda - lambda value used for regularization. Defaults to 0.00005.",
        default=0.00005)
    parser.add_argument(
        "--data-folder",
        "-d",
        dest="data_folder",
        type=str,
        help="Data folder. If unspecified, defaults to raw_data_folder/..",
        default=None)
    parser.add_argument("--raw-data-folder",
                        dest="raw_data_folder",
                        type=str,
                        help="Raw data folder. Must be specified.",
                        default=None)
    parser.add_argument(
        "--vgg16-folder",
        dest="vgg16_folder",
        type=str,
        help=
        "vgg16 data folder. If unspecified, defaults to data_folder/vgg16_features",
        default=None)
    parser.add_argument(
        "--image-folder",
        dest="image_folder",
        type=str,
        help=
        "image folder. If unspecified, defaults to data_folder/formula_images",
        default=None)
    parser.add_argument(
        "--partial-batch",
        "-p",
        dest="partial_batch",
        action='store_true',
        help="Sets the assert_whole_batch hyper param to False. This option defaults to False "
        "(i.e. assert_whole_batch=True).",
        default=False)
    parser.add_argument(
        "--queue-capacity",
        "-q",
        dest="queue_capacity",
        type=int,
        help=
        "Capacity of input queue. Defaults to hyperparam defaults if unspecified.",
        default=None)
    parser.add_argument(
        "--logging-level",
        "-l",
        dest="logging_level",
        type=int,
        choices=range(1, 6),
        help=
        "Logging verbosity level from 1 to 5 in increasing order of verbosity.",
        default=4)
    parser.add_argument(
        "--valid-frac",
        "-f",
        dest="valid_frac",
        type=float,
        help="Fraction of samples to use for validation. Defaults to 0.05",
        default=0.05)
    parser.add_argument(
        "--validation-epochs",
        "-v",
        dest="valid_epochs",
        type=float,
        help=
        """Number (or fraction) of epochs after which to run a full validation cycle. For this
                             behaviour, the number should be greater than 0. A value less <= 0 on the other hand,
                             implies 'smart' validation - which will result in selectively 
                             capturing snapshots around max_scoring peaks (based on training/bleu2).""",
        default=1.0)
    parser.add_argument(
        "--save-all-eval",
        dest="save_all_eval",
        action='store_true',
        help="(Boolean): False => save only one random validation/testing batch; "
        "True => save all validation/testing batches.",
        default=False)
    parser.add_argument(
        "--build-image-context",
        "-i",
        dest="build_image_context",
        type=int,
        help=
        "Sets value of hyper.build_image_context. Default is 2 => build my own convnet.",
        default=2)
    parser.add_argument(
        "--swap-memory",
        dest="swap_memory",
        action='store_true',
        help="swap_memory option of tf.scan and tf.while_loop. Default to False."
        " Enabling allows training larger mini-batches at the cost of speed.",
        default=False)
    parser.add_argument(
        "--restore",
        dest="restore_logdir",
        type=str,
        help=
        "restore from checkpoint. Provide logdir path as argument. Don't specify the --logdir argument.",
        default=None)
    parser.add_argument(
        "--logdir",
        dest="logdir",
        type=str,
        help=
        "(optional) Sets TensorboardParams.tb_logdir. Can't specify the --restore argument along with this.",
        default=None)
    parser.add_argument(
        "--logdir-tag",
        dest="logdir_tag",
        type=str,
        help=
        "(optional) Sets TensorboardParams.logdir_tab. Can't specify the --restore argument along with this.",
        default=None)
    # parser.add_argument("--use-ctc-loss", dest="use_ctc_loss", action='store_true',
    #                     help="Sets the use_ctc_loss hyper parameter. Defaults to False.",
    #                     default=False)
    parser.add_argument(
        "--validate",
        dest="doValidate",
        action='store_true',
        help=
        "Run validation cycle only. --restore option should be provided along with this.",
        default=False)
    parser.add_argument(
        "--test",
        dest="doTest",
        action='store_true',
        help=
        "Run test cycle only but with training dataset. --restore option should be provided along with this.",
        default=False)
    parser.add_argument(
        "--squash-input-seq",
        dest="squash_input_seq",
        action='store_true',
        help=
        "(boolean) Set value of squash_input_seq hyper param. Defaults to True.",
        default=True)
    parser.add_argument(
        "--num-snapshots",
        dest="num_snapshots",
        type=int,
        help=
        "Number of latest snapshots to save. Defaults to 100 if unspecified",
        default=100)

    args = parser.parse_args()

    raw_data_folder = args.raw_data_folder
    if args.data_folder:
        data_folder = args.data_folder
    else:
        data_folder = os.path.join(raw_data_folder, '..')

    if args.image_folder:
        image_folder = args.image_folder
    else:
        image_folder = os.path.join(data_folder, 'formula_images')

    if args.vgg16_folder:
        vgg16_folder = args.vgg16_folder
    else:
        vgg16_folder = os.path.join(data_folder, 'vgg16_features')

    if args.restore_logdir is not None:
        assert args.logdir is None, 'Only one of --restore and --logdir can be specified.'
        assert args.logdir_tag is None, "--logdir-tag can't be specified alongside --restore"
        tb = tfc.TensorboardParams({
            'tb_logdir': os.path.dirname(args.restore_logdir)
        }).freeze()
    elif args.logdir is not None:
        tb = tfc.TensorboardParams({
            'tb_logdir': args.logdir,
            'logdir_tag': args.logdir_tag
        }).freeze()
    else:
        tb = tfc.TensorboardParams({
            'tb_logdir': './tb_metrics',
            'logdir_tag': args.logdir_tag
        }).freeze()

    if args.doValidate:
        assert args.restore_logdir is not None, 'Please specify --restore option along with --validate'
        assert not args.doTest, '--test and --validate cannot be given together'

    if args.doTest:
        assert args.restore_logdir is not None, 'Please specify --restore option along with --test'
        assert not args.doValidate, '--test and --validate cannot be given together'

    globalParams = dlc.Properties({
        'raw_data_dir': raw_data_folder,
        'assert_whole_batch': not args.partial_batch,
        'logger': logger,
        'tb': tb,
        'print_steps': args.print_steps,
        'num_steps': args.num_steps,
        'num_epochs': args.num_epochs,
        'num_snapshots': args.num_snapshots,
        'data_dir': data_folder,
        'generated_data_dir': data_folder,
        'image_dir': image_folder,
        'ctc_beam_width': args.ctc_beam_width,
        'seq2seq_beam_width': args.seq2seq_beam_width,
        'valid_frac': args.valid_frac,
        'valid_epochs': args.valid_epochs,
        'save_all_eval': args.save_all_eval,
        'build_image_context': args.build_image_context,
        'sum_logloss': False,  # setting to true equalizes ctc_loss and log_loss if y_s == squashed_seq
        'dropout': None if args.keep_prob >= 1.0 else tfc.DropoutParams({'keep_prob': args.keep_prob}).freeze(),
        'MeanSumAlphaEquals1': False,
        'rLambda': args.rLambda,  # 0.0005, 0.00005
        'make_training_accuracy_graph': False,
        # 'use_ctc_loss': args.use_ctc_loss,
        'swap_memory': args.swap_memory,
        'tf_session_allow_growth': False,
        'restore_from_checkpoint': args.restore_logdir is not None,
        'num_gpus': 2,
        'towers_per_gpu': 1,
        'beamsearch_length_penalty': 1.0,
        'doValidate': args.doValidate,
        'doTest': args.doTest,
        'doTrain': not (args.doValidate or args.doTest),
        'squash_input_seq': args.squash_input_seq,
        'att_model': 'MLP_full',  # '1x1_conv', 'MLP_shared', 'MLP_full'
        'weights_regularizer': tf.contrib.layers.l2_regularizer(scale=1.0, scope='L2_Regularizer'),
        # 'embeddings_regularizer': None,
        # 'outputMLP_skip_connections': False,
        'output_reuse_embeddings': False,
        'REGROUP_IMAGE': (4, 1),  # None  # (4,1)
        'build_att_modulator': False,  # turn off beta-MLP
        'build_scanning_RNN': False,
        'init_model_input_transform': 'full',
        'build_init_model': True,
        'adam_beta1': 0.5,
        'adam_beta2': 0.9,
        'pLambda': 0.0
    })

    if args.batch_size is not None:
        globalParams.B = args.batch_size

    if args.queue_capacity is not None:
        globalParams.input_queue_capacity = args.queue_capacity
    if args.alpha is not None:
        globalParams.adam_alpha = args.alpha

    if args.restore_logdir is not None:
        globalParams.logdir = args.restore_logdir
    else:
        globalParams.logdir = dtc.makeTBDir(tb.tb_logdir, tb.logdir_tag)

    # Dump the param overrides (args) to the store directory
    globalParams.storedir = dtc.makeLogDir(globalParams.logdir, 'store')
    globalParams.dump(dtc.makeLogfileName(globalParams.storedir, 'args.pkl'))

    # Hyper Params
    hyper = hyper_params.make_hyper(globalParams, freeze=False)
    if args.restore_logdir is not None:
        hyper.dump(dtc.makeLogfileName(globalParams.storedir, 'hyper.pkl'))
    else:
        hyper.dump(globalParams.storedir, 'hyper.pkl')

    # Logger
    fh = logging.FileHandler(
        dtc.makeLogfileName(globalParams.storedir, 'training.log'))
    fh.setFormatter(dtc.makeFormatter())
    logger.addHandler(fh)
    dtc.setLogLevel(logger, args.logging_level)

    logger.info(' '.join(sys.argv))
    logger.info(
        '\n#################### Default Param Overrides: ####################\n%s',
        globalParams.pformat())
    logger.info(
        '##################################################################\n')
    logger.info(
        '\n#########################  Hyper-params: #########################\n%s',
        hyper.pformat())
    logger.info(
        '##################################################################\n')

    train_multi_gpu.main(raw_data_folder, vgg16_folder, globalParams,
                         hyper.freeze())
Example #10
def main():
    _data_folder = '../data/dataset3'

    parser = arg.ArgumentParser(description='train model')
    parser.add_argument(
        "--num-steps",
        "-n",
        dest="num_steps",
        type=int,
        help=
        "Number of training steps to run. Defaults to -1 if unspecified, i.e. run to completion",
        default=-1)
    parser.add_argument(
        "--num-epochs",
        "-e",
        dest="num_epochs",
        type=int,
        help="Number of training steps to run. Defaults to 1 if unspecified.",
        default=1)
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        type=int,
        help=
        "Batchsize. If unspecified, defaults to the default value in hyper_params",
        default=None)
    parser.add_argument(
        "--print-steps",
        "-s",
        dest="print_steps",
        type=int,
        help=
        "Number of training steps after which to log results. Defaults to 10 if unspecified",
        default=100)
    parser.add_argument("--data-folder",
                        "-d",
                        dest="data_folder",
                        type=str,
                        help="Data folder. If unspecified, defaults to " +
                        _data_folder,
                        default=_data_folder)
    parser.add_argument(
        "--raw-data-folder",
        dest="raw_data_folder",
        type=str,
        help=
        "Raw data folder. If unspecified, defaults to data_folder/training",
        default=None)
    parser.add_argument(
        "--vgg16-folder",
        dest="vgg16_folder",
        type=str,
        help=
        "vgg16 data folder. If unspecified, defaults to data_folder/vgg16_features",
        default=None)
    parser.add_argument(
        "--image-folder",
        dest="image_folder",
        type=str,
        help=
        "image folder. If unspecified, defaults to data_folder/formula_images",
        default=None)
    parser.add_argument(
        "--partial-batch",
        "-p",
        dest="partial_batch",
        action='store_true',
        help=
        "Sets assert_whole_batch hyper param to False. Default hyper_param value will be used if unspecified"
    )
    parser.add_argument(
        "--logging-level",
        "-l",
        dest="logging_level",
        type=int,
        help=
        "Logging verbosity level from 1 to 5 in increasing order of verbosity.",
        default=4)

    args = parser.parse_args()
    data_folder = args.data_folder
    params = dlc.Properties({
        'num_steps': args.num_steps,
        'print_steps': args.print_steps,
        'num_epochs': args.num_epochs,
        'logger': dtc.makeLogger(args.logging_level, set_global=True),
        'build_image_context': 1,
        'weights_regularizer': None,
        'num_gpus': 1,
        'tb': tfc.TensorboardParams({'tb_logdir': 'tb_metrics_convnet'}).freeze()
    })
    if args.image_folder:
        params.image_folder = args.image_folder
    else:
        params.image_folder = os.path.join(data_folder, 'formula_images')

    if args.raw_data_folder:
        params.raw_data_folder = args.raw_data_folder
    else:
        params.raw_data_folder = os.path.join(data_folder, 'training')
    params.raw_data_dir = params.raw_data_folder

    if args.vgg16_folder:
        params.vgg16_folder = args.vgg16_folder
    else:
        params.vgg16_folder = os.path.join(data_folder, 'vgg16_features')

    if args.batch_size is not None:
        params.B = args.batch_size
    if args.partial_batch:
        params.assert_whole_batch = False

    data_props = dtc.load(params.raw_data_folder, 'data_props.pkl')
    params.image_shape = (data_props['padded_image_dim']['height'],
                          data_props['padded_image_dim']['width'], 3)
    run_convnet(params)
Example #11
def evaluate(session, ops, batch_its, hyper, args, step, num_steps, tf_sw, training_logic):
    training_logic._set_flags_after_validation()
    eval_start_time = time.time()

    eval_ops = ops.eval_ops
    batch_it = batch_its.eval_it
    batch_size = batch_it.batch_size
    # Print a batch randomly
    print_batch_num = np.random.randint(1, num_steps+1) if not args.save_all_eval else None
    accum = Accumulator()
    n = 0
    hyper.logger.info('evaluation cycle starting at step %d for %d steps', step, num_steps)
    while n < num_steps:
        n += 1
        if (n == print_batch_num) or args.save_all_eval:
            (l, mean_ed, accuracy, num_hits, top1_ids_list, top1_lens, y_ctc_list, ctc_len, y_s_list, top1_alpha_list,
             top1_beta_list, image_name_list, top1_ed) = session.run(
                (
                    eval_ops.top1_len_ratio,
                    eval_ops.top1_mean_ed,
                    eval_ops.top1_accuracy,
                    eval_ops.top1_num_hits,
                    eval_ops.top1_ids_list,
                    eval_ops.top1_lens,
                    eval_ops.y_ctc_list,
                    eval_ops.ctc_len,
                    eval_ops.y_s_list,
                    eval_ops.top1_alpha_list,
                    eval_ops.top1_beta_list,
                    eval_ops.image_name_list,
                    eval_ops.top1_ed
                ))
            if args.save_all_eval:
                accum.extend({'y': y_s_list,
                              'top1_ids': top1_ids_list,
                              'alpha': top1_alpha_list,
                              'beta': top1_beta_list,
                              'image_name': image_name_list
                              })
                accum.append({'ed': top1_ed})
        else:
            l, mean_ed, accuracy, num_hits, top1_ids_list, top1_lens, y_ctc_list, ctc_len = session.run((
                                eval_ops.top1_len_ratio,
                                eval_ops.top1_mean_ed,
                                eval_ops.top1_accuracy,
                                eval_ops.top1_num_hits,
                                eval_ops.top1_ids_list,
                                eval_ops.top1_lens,
                                eval_ops.y_ctc_list,
                                eval_ops.ctc_len
                                ))
            y_s_list = top1_alpha_list = top1_beta_list = image_name_list = top1_ed = None

        bleu = sentence_bleu_scores(hyper, top1_ids_list, top1_lens, y_ctc_list, ctc_len)
        accum.append({'bleus': bleu})
        accum.extend({'sq_predicted_ids': squashed_seq_list(hyper, top1_ids_list, top1_lens)})
        accum.extend({'trim_target_ids': trimmed_seq_list(hyper, y_ctc_list, ctc_len)})
        accum.append({'len_ratio': l})
        accum.append({'mean_eds': mean_ed})
        accum.append({'accuracies': accuracy})
        accum.append({'hits': num_hits})

        if n == print_batch_num:
            # logger.info('############ RANDOM VALIDATION BATCH %d ############', n)
            # logger.info('prediction mean_ed=%f', mean_ed)
            # logger.info('prediction accuracy=%f', accuracy)
            # logger.info('prediction hits=%d', num_hits)
            # bleu = sentence_bleu_scores(hyper, top1_ids_list, top1_lens, y_ctc_list, ctc_len)

            with dtc.Storer(args, 'test' if args.doTest else 'validation', step) as storer:
                storer.write('predicted_ids', top1_ids_list, np.int16)
                storer.write('y', y_s_list, np.int16)
                storer.write('alpha', top1_alpha_list, dtype=np.float32, batch_axis=1)
                storer.write('beta', top1_beta_list, dtype=np.float32, batch_axis=1)
                storer.write('image_name', image_name_list, dtype=np.unicode_)
                storer.write('ed', top1_ed, dtype=np.float32)
                storer.write('bleu', bleu, dtype=np.float32)
            # logger.info( '############ END OF RANDOM VALIDATION BATCH ############')

    agg_bleu2 = dlc.corpus_bleu_score(accum.sq_predicted_ids, accum.trim_target_ids)
    eval_time_per100 = (time.time() - eval_start_time) * 100. / (num_steps * batch_size)

    if args.save_all_eval:
        assert len(accum.y) == len(accum.top1_ids) == len(accum.alpha) == len(accum.beta) == len(accum.image_name)
        assert len(accum.bleus) == len(accum.ed), 'len bleus = %d, len ed = %d' % (len(accum.bleus), len(accum.ed),)
        with dtc.Storer(args, 'test' if args.doTest else 'validation', step) as storer:
            storer.write('predicted_ids', accum.top1_ids, np.int16)
            storer.write('y', accum.y, np.int16)
            storer.write('alpha', accum.alpha, dtype=np.float32, batch_axis=1)
            storer.write('beta', accum.beta, dtype=np.float32, batch_axis=1)
            storer.write('image_name', accum.image_name, dtype=np.unicode_)
            storer.write('ed', accum.ed, dtype=np.float32)
            storer.write('bleu', accum.bleus, dtype=np.float32)

    logs_agg_top1 = session.run(eval_ops.logs_agg_top1,
                                feed_dict={
                                    eval_ops.ph_top1_len_ratio: accum.len_ratio,
                                    eval_ops.ph_edit_distance: accum.mean_eds,
                                    eval_ops.ph_num_hits: accum.hits,
                                    eval_ops.ph_accuracy: accum.accuracies,
                                    eval_ops.ph_valid_time: eval_time_per100,
                                    eval_ops.ph_bleus: accum.bleus,
                                    eval_ops.ph_bleu2: agg_bleu2,
                                    eval_ops.ph_full_validation: 1 if (num_steps == batch_it.epoch_size) else 0
                                })
    with dtc.Storer(args, 'metrics_test' if args.doTest else 'metrics_validation', step) as storer:
        storer.write('edit_distance', np.mean(accum.mean_eds, keepdims=True), np.float32)
        storer.write('num_hits', np.sum(accum.hits, keepdims=True), dtype=np.uint32)
        storer.write('accuracy', np.mean(accum.accuracies, keepdims=True), np.float32)
        storer.write('bleu', np.mean(accum.bleus, keepdims=True), dtype=np.float32)
        storer.write('bleu2', np.asarray([agg_bleu2]), dtype=np.float32)

    tf_sw.add_summary(logs_agg_top1, standardized_step(step))
    tf_sw.flush()
    hyper.logger.info('validation cycle finished. bleu2 = %f', agg_bleu2)
    return dlc.Properties({'eval_time_per100': eval_time_per100})
Example #12
def evaluate_scanning_RNN(args, hyper, session, ops, ops_accum, ops_log, tr_step, num_steps, tf_sw, training_logic):
    training_logic._set_flags_after_validation()
    start_time = time.time()
    ############################# Validation Or Test Cycle ##############################
    hyper.logger.info('validation cycle starting at step %d for %d steps', tr_step, num_steps)
    accum = Accumulator()
    print_batch_num = np.random.randint(1, num_steps + 1) if not args.save_all_eval else None
    for batch in range(1, 1+num_steps):
        step_start_time = time.time()
        doLog = (print_batch_num == batch)
        if doLog or args.save_all_eval:
            batch_ops = TFRun(session, ops, ops_accum + ops_log, None)
        else:
            batch_ops = TFRun(session, ops, ops_accum, None)

        batch_ops.run_ops()

        ## Accumulate Metrics
        accum.append(batch_ops)
        batch_time = time.time() - step_start_time
        bleu = sentence_bleu_scores(hyper,
                                    batch_ops.predicted_ids_list,
                                    batch_ops.predicted_lens,
                                    batch_ops.y_ctc_list,
                                    batch_ops.ctc_len)
        accum.extend({'bleu_scores': bleu})
        accum.extend({
            'sq_predicted_ids': squashed_seq_list(hyper, batch_ops.predicted_ids_list, batch_ops.predicted_lens),
            'sq_y_ctc': trimmed_seq_list(hyper, batch_ops.y_ctc_list, batch_ops.ctc_len),
            'predicted_ids': batch_ops.predicted_ids_list,
            'y': batch_ops.y_s_list,
            'alpha': batch_ops.alpha,
            'beta': batch_ops.beta,
            'image_name': batch_ops.image_name_list,
            'ctc_ed': batch_ops.ctc_ed,
            'bin_len': batch_ops.bin_len,
            'scan_len': batch_ops.scan_len,
            'x': batch_ops.x_s_list
        })
        accum.append({'batch_time': batch_time})

        if doLog:
            with dtc.Storer(args, 'test' if args.doTest else 'validation', tr_step) as storer:
                storer.write('predicted_ids', batch_ops.predicted_ids_list, np.int16)
                storer.write('y', batch_ops.y_s_list, np.int16)
                storer.write('alpha', batch_ops.alpha, np.float32, batch_axis=1)
                storer.write('beta', batch_ops.beta, np.float32, batch_axis=1)
                storer.write('image_name', batch_ops.image_name_list, dtype=np.unicode_)
                storer.write('ed', batch_ops.ctc_ed, np.float32)
                storer.write('bleu', bleu, np.float32)
                storer.write('bin_len', batch_ops.bin_len, np.float32)
                storer.write('scan_len', batch_ops.scan_len, np.float32)
                storer.write('x', batch_ops.x_s_list, np.float32)

    if args.save_all_eval:
        with dtc.Storer(args, 'test' if args.doTest else 'validation', tr_step) as storer:
            storer.write('predicted_ids', accum.predicted_ids, np.int16)
            storer.write('y', accum.y, np.int16)
            storer.write('alpha', accum.alpha, np.float32, batch_axis=1)
            storer.write('beta', accum.beta, np.float32, batch_axis=1)
            storer.write('image_name', accum.image_name, dtype=np.unicode_)
            storer.write('ed', accum.ctc_ed, np.float32)
            storer.write('bleu', accum.bleu_scores, np.float32)
            storer.write('bin_len', accum.bin_len, np.float32)
            storer.write('scan_len', accum.scan_len, np.float32)
            storer.write('x', accum.x, np.float32)

    # Calculate aggregated validation metrics
    eval_time_per100 = np.mean(accum.batch_time) * 100. / hyper.data_reader_B
    tb_agg_logs = session.run(ops.tb_agg_logs, feed_dict={
        ops.ph_train_time: eval_time_per100,
        ops.ph_bleu_scores: accum.bleu_scores,
        ops.ph_bleu_score2: dlc.corpus_bleu_score(accum.sq_predicted_ids, accum.sq_y_ctc),
        ops.ph_ctc_eds: accum.ctc_ed,
        ops.ph_loglosses: accum.log_likelihood,
        ops.ph_ctc_losses: accum.ctc_loss,
        ops.ph_alpha_penalties: accum.alpha_penalty,
        ops.ph_costs: accum.cost,
        ops.ph_mean_norm_ases: accum.mean_norm_ase,
        ops.ph_mean_norm_aaes: accum.mean_norm_aae,
        ops.ph_beta_mean: accum.beta_mean,
        ops.ph_beta_std_dev: accum.beta_std_dev,
        ops.ph_pred_len_ratios: accum.pred_len_ratio,
        ops.ph_num_hits: accum.num_hits,
        ops.ph_reg_losses: accum.reg_loss,
        ops.ph_scan_lens: accum.scan_len
    })
    tf_sw.add_summary(tb_agg_logs, global_step=standardized_step(tr_step))
    tf_sw.flush()
    return dlc.Properties({'eval_time_per100': eval_time_per100})
Example #13
def main(raw_data_folder,
         vgg16_folder,
         args,
         hyper):
    """
    Start training the model.
    """
    dtc.initialize(args.raw_data_dir, hyper)
    global logger
    logger = hyper.logger
    global standardized_step
    standardized_step = make_standardized_step(hyper)

    graph = tf.Graph()
    with graph.as_default():
        if hyper.build_image_context == 1:
            train_it, eval_it = create_imagenet_iterators(raw_data_folder,
                                                          hyper,
                                                          args)
        elif hyper.build_image_context == 2:
            train_it, eval_it = create_BW_image_iterators(raw_data_folder,
                                                          hyper,
                                                          args)
        else:
            train_it, eval_it = create_context_iterators(raw_data_folder,
                                                         vgg16_folder,
                                                         hyper,
                                                         args)

        qrs = []
        ##### Training Graphs
        train_tower_ops = []; train_ops = None
        trainable_vars_n = 0
        toplevel_var_scope = tf.get_variable_scope()
        reuse_vars = False
        with tf.name_scope('Training'):
            tf_train_step = tf.get_variable('global_step', dtype=hyper.int_type, trainable=False, initializer=0)
            if hyper.optimizer == 'adam':
                opt = tf.train.AdamOptimizer(learning_rate=hyper.adam_alpha, beta1=hyper.adam_beta1, beta2=hyper.adam_beta2)
            else:
                raise Exception('Unsupported optimizer - %s - configured.' % (hyper.optimizer,))

            if args.doTrain:
                with tf.variable_scope('InputQueue'):
                    train_q = tf.FIFOQueue(hyper.input_queue_capacity, train_it.out_tup_types)
                    tf_enqueue_train_queue = train_q.enqueue_many(train_it.get_pyfunc_with_split(hyper.num_towers))
                    tf_close_train_queue = train_q.close(cancel_pending_enqueues=True)
                for i in range(hyper.num_gpus):
                    for j in range(hyper.towers_per_gpu):
                        with tf.name_scope('gpu_%d'%i + ('/tower_%d'%j if hyper.towers_per_gpu > 1 else '')):
                            with tf.device('/gpu:%d'%i):
                                model = Im2LatexModel(hyper, train_q, opt=opt, reuse=reuse_vars)
                                train_tower_ops.append(model.build_training_tower())
                                if not reuse_vars:
                                    trainable_vars_n = num_trainable_vars()  # 8544670 or 8547670
                                    hyper.logger.info('Num trainable variables = %d', trainable_vars_n)
                                    reuse_vars = True
                                    ## assert trainable_vars_n == 8547670 if hyper.use_peephole else 8544670
                                    ## assert trainable_vars_n == 23261206 if hyper.build_image_context
                                else:
                                    assert num_trainable_vars() == trainable_vars_n, 'trainable_vars %d != expected %d'%(num_trainable_vars(), trainable_vars_n)
                train_ops = sync_training_towers(hyper, train_tower_ops, tf_train_step, optimizer=opt)
        if args.doTrain:
            qr1 = tf.train.QueueRunner(train_q, [tf_enqueue_train_queue], cancel_op=[tf_close_train_queue])
            qrs.append(qr1)

        ##### Validation/Testing Graph
        eval_tower_ops = []; eval_ops = None
        if eval_it:  # and (args.doTrain or args.doValidate):
            with tf.name_scope('Validation' if not args.doTest else 'Testing'):
                hyper_predict = hyper_params.make_hyper(args.copy().updated({'dropout':None}))
                with tf.variable_scope('InputQueue'):
                    eval_q = tf.FIFOQueue(hyper.input_queue_capacity, eval_it.out_tup_types)
                    enqueue_op2 = eval_q.enqueue_many(eval_it.get_pyfunc_with_split(hyper.num_towers))
                    close_queue2 = eval_q.close(cancel_pending_enqueues=True)
                for i in range(args.num_gpus):
                    for j in range(args.towers_per_gpu):
                        with tf.name_scope('gpu_%d' % i + ('/tower_%d' % j if hyper.towers_per_gpu > 1 else '')):
                            with tf.device('/gpu:%d'%i):
                                if hyper.build_scanning_RNN:
                                    model_predict = Im2LatexModel(hyper_predict,
                                                                  eval_q,
                                                                  reuse=reuse_vars)
                                    eval_tower_ops.append(model_predict.build_training_tower())
                                else:
                                    model_predict = Im2LatexModel(hyper_predict,
                                                                  eval_q,
                                                                  seq2seq_beam_width=hyper.seq2seq_beam_width,
                                                                  reuse=reuse_vars)
                                    eval_tower_ops.append(model_predict.build_testing_tower())
                                if not reuse_vars:
                                    trainable_vars_n = num_trainable_vars()
                                    reuse_vars = True
                                else:
                                    assert num_trainable_vars() == trainable_vars_n, 'trainable_vars %d != expected %d' % (
                                        num_trainable_vars(), trainable_vars_n)

                hyper.logger.info('Num trainable variables = %d', num_trainable_vars())
                assert num_trainable_vars() == trainable_vars_n, 'num_trainable_vars(%d) != %d'%(num_trainable_vars(), trainable_vars_n)
                if hyper.build_scanning_RNN:
                    eval_ops = sync_training_towers(hyper, eval_tower_ops, global_step=None, run_tag='validation' if not args.doTest else 'testing')
                else:
                    eval_ops = sync_testing_towers(hyper, eval_tower_ops, run_tag='validation' if not args.doTest else 'testing')
            qr2 = tf.train.QueueRunner(eval_q, [enqueue_op2], cancel_op=[close_queue2])
            qrs.append(qr2)

        # ##### Training Accuracy Graph
        # if (args.make_training_accuracy_graph):
        #     with tf.name_scope('TrainingAccuracy'):
        #         hyper_predict2 = hyper_params.make_hyper(args.copy().updated({'dropout':None}))
        #         with tf.device('/gpu:1'):
        #             model_predict2 = Im2LatexModel(hyper_predict, hyper.seq2seq_beam_width, reuse=True)
        #             tr_acc_ops = model_predict2.test()
        #         with tf.variable_scope('QueueOps'):
        #             enqueue_op3 = tr_acc_ops.inp_q.enqueue_many(tr_acc_it.get_pyfunc_with_split(hyper.num_towers))
        #             close_queue3 = tr_acc_ops.inp_q.close(cancel_pending_enqueues=True)
        #         assert(num_trainable_vars() == trainable_vars_n)
        #     qr3 = tf.train.QueueRunner(tr_acc_ops.inp_q, [enqueue_op3], cancel_op=[close_queue3])
        # else:
        tr_acc_ops = None

        coord = tf.train.Coordinator()
        training_logic = TrainingLogic(args, coord, train_it, eval_it)

        # print train_ops

        printVars(logger)

        config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
        config.gpu_options.allow_growth = hyper.tf_session_allow_growth

        with tf.Session(config=config) as session:
            logger.info('Flushing graph to disk')
            tf_sw = tf.summary.FileWriter(args.logdir, graph=graph)
            # tf_params = tf.constant(value=hyper.to_table(), dtype=tf.string, name='hyper_params')

            # tf_text = tf.summary.text('hyper_params_logger', tf_params)
            # log_params = session.run(tf_text)
            # tf_sw.add_summary(log_params, global_step=None)
            tf_sw.flush()

            enqueue_threads = []
            for qr in qrs:
                enqueue_threads.extend(qr.create_threads(session, coord=coord, start=True))
            # enqueue_threads = qr1.create_threads(session, coord=coord, start=True)
            # enqueue_threads.extend(qr2.create_threads(session, coord=coord, start=True))
            # if args.make_training_accuracy_graph:
            #     enqueue_threads.extend(qr3.create_threads(session, coord=coord, start=True))
            logger.info('Created enqueue threads')

            saver = tf.train.Saver(max_to_keep=args.num_snapshots, pad_step_number=True, save_relative_paths=True)
            if args.restore_from_checkpoint:
                latest_checkpoint = tf.train.latest_checkpoint(args.logdir, latest_filename='checkpoints_list')
                logger.info('Restoring session from checkpoint %s', latest_checkpoint)
                saver.restore(session, latest_checkpoint)
                step = tf_train_step.eval()
                logger.info('Restored session from checkpoint %s at step %d', latest_checkpoint, step)
            else:
                tf.global_variables_initializer().run()
                step = 0
                logger.info('Starting a new session')

            # Ensure that everything was initialized
            assert len(tf.report_uninitialized_variables().eval()) == 0

            try:
                start_time = time.time()
                ############################# Training (with Validation) Cycle ##############################
                ops_accum = (
                    'train',
                    'predicted_ids_list',
                    'predicted_lens',
                    'y_ctc_list',
                    'ctc_len',
                    'ctc_ed',
                    'log_likelihood',
                    'ctc_loss',
                    'alpha_penalty',
                    'cost',
                    'mean_norm_ase',
                    'mean_norm_aae',
                    'beta_mean',
                    'beta_std_dev',
                    'pred_len_ratio',
                    'num_hits',
                    'reg_loss',
                    'scan_len',
                    'bin_len',
                    'global_step'
                )

                ops_log = (
                    'y_s_list',
                    'predicted_ids_list',
                    'alpha',
                    'beta',
                    'image_name_list',
                    'x_s_list',
                    'tb_step_logs'
                )
                if args.doTrain:
                    logger.info('Starting training')
                    accum = Accumulator()
                    while not training_logic.should_stop():
                        step_start_time = time.time()
                        step += 1
                        doLog = training_logic.do_log(step)
                        if not doLog:
                            batch_ops = TFOpNames(ops_accum, None)
                        else:
                            batch_ops = TFOpNames(ops_accum + ops_log, None)

                        batch_ops.run_ops(session, train_ops)
                        assert step == batch_ops.global_step, 'Step(%d) and global-step(%d) fell out of sync ! :((' % (step, batch_ops.global_step)

                        # Accumulate Metrics
                        accum.append(batch_ops)
                        train_time = time.time()-step_start_time
                        bleu = sentence_bleu_scores(hyper, batch_ops.predicted_ids_list, batch_ops.predicted_lens, batch_ops.y_ctc_list, batch_ops.ctc_len)
                        accum.extend({'bleu_scores': bleu})
                        accum.extend({
                            'sq_predicted_ids': squashed_seq_list(hyper, batch_ops.predicted_ids_list, batch_ops.predicted_lens),
                            'sq_y_ctc':         trimmed_seq_list(hyper, batch_ops.y_ctc_list, batch_ops.ctc_len)
                        })
                        accum.append({'train_time': train_time})

                        if doLog:
                            logger.info('Step %d', step)
                            train_time_per100 = np.mean(train_time) * 100. / (hyper.data_reader_B)

                            with dtc.Storer(args, 'training', step) as storer:
                                storer.write('predicted_ids', batch_ops.predicted_ids_list, np.int16)
                                storer.write('y', batch_ops.y_s_list, np.int16)
                                storer.write('alpha', batch_ops.alpha, np.float32, batch_axis=1)
                                storer.write('beta', batch_ops.beta, np.float32, batch_axis=1)
                                storer.write('image_name', batch_ops.image_name_list, dtype=np.unicode_)
                                storer.write('ed', batch_ops.ctc_ed, np.float32)
                                storer.write('bleu', bleu, np.float32)
                                storer.write('bin_len', batch_ops.bin_len, np.float32)
                                storer.write('scan_len', batch_ops.scan_len, np.float32)
                                storer.write('x', batch_ops.x_s_list, np.float32)

                            # per-step metrics
                            tf_sw.add_summary(batch_ops.tb_step_logs, global_step=standardized_step(step))
                            tf_sw.flush()

                            # aggregate metrics
                            agg_bleu2 = dlc.corpus_bleu_score(accum.sq_predicted_ids, accum.sq_y_ctc)
                            tb_agg_logs = session.run(train_ops.tb_agg_logs, feed_dict={
                                train_ops.ph_train_time: train_time_per100,
                                train_ops.ph_bleu_scores: accum.bleu_scores,
                                train_ops.ph_bleu_score2: agg_bleu2,
                                train_ops.ph_ctc_eds: accum.ctc_ed,
                                train_ops.ph_loglosses: accum.log_likelihood,
                                train_ops.ph_ctc_losses: accum.ctc_loss,
                                train_ops.ph_alpha_penalties: accum.alpha_penalty,
                                train_ops.ph_costs: accum.cost,
                                train_ops.ph_mean_norm_ases: accum.mean_norm_ase,
                                train_ops.ph_mean_norm_aaes: accum.mean_norm_aae,
                                train_ops.ph_beta_mean: accum.beta_mean,
                                train_ops.ph_beta_std_dev: accum.beta_std_dev,
                                train_ops.ph_pred_len_ratios: accum.pred_len_ratio,
                                train_ops.ph_num_hits: accum.num_hits,
                                train_ops.ph_reg_losses: accum.reg_loss,
                                train_ops.ph_scan_lens: accum.scan_len
                            })

                            tf_sw.add_summary(tb_agg_logs, global_step=standardized_step(step))
                            tf_sw.flush()

                            doValidate, num_validation_batches, do_save = training_logic.do_validate(step, (agg_bleu2 if (args.valid_epochs <= 0) else None))
                            if do_save:
                                saver.save(session, args.logdir + '/snapshot', global_step=step,
                                           latest_filename='checkpoints_list')
                            if doValidate:
                                if hyper.build_scanning_RNN:
                                    accuracy_res = evaluate_scanning_RNN(args, hyper, session, eval_ops, ops_accum,
                                                                         ops_log, step, num_validation_batches, tf_sw,
                                                                         training_logic)
                                else:
                                    accuracy_res = evaluate(
                                        session,
                                        dlc.Properties({'eval_ops':eval_ops}),
                                        dlc.Properties({'train_it':train_it, 'eval_it':eval_it}),
                                        hyper,
                                        args,
                                        step,
                                        num_validation_batches,
                                        tf_sw,
                                        training_logic)
                                logger.info('Time for %d steps, elapsed = %f, training-time-per-100 = %f, validation-time-per-100 = %f'%(
                                    step,
                                    time.time()-start_time,
                                    train_time_per100,
                                    accuracy_res.eval_time_per100))
                            else:
                                logger.info('Time for %d steps, elapsed = %f, training-time-per-100 = %f'%(
                                    step,
                                    time.time()-start_time,
                                    train_time_per100))

                            # Reset Metrics
                            accum.reset()

                ############################# Validation/Testing Only ##############################
                elif args.doValidate or args.doTest:
                    logger.info('Starting %s Cycle'%('Validation' if args.doValidate else 'Testing',))
                    if hyper.build_scanning_RNN:
                        evaluate_scanning_RNN(args, hyper, session, eval_ops, ops_accum,
                                              ops_log, step,
                                              eval_it.epoch_size,
                                              tf_sw,
                                              training_logic)
                    else:
                        evaluate(session,
                                 dlc.Properties({'eval_ops': eval_ops}),
                                 dlc.Properties({'train_it': train_it, 'eval_it': eval_it}),
                                 hyper,
                                 args,
                                 step,
                                 eval_it.epoch_size,
                                 tf_sw,
                                 training_logic)


            except (tf.errors.OutOfRangeError, StopIteration):
                logger.info('Done training -- epoch limit reached')
            except Exception as e:
                logger.info('***************** Exiting with exception: *****************\n%s' % e)
                coord.request_stop(e)
            finally: