Ejemplo n.º 1
0
 def CRNN_init(self):
     """Construct a CRNN (32px height, 1 channel, 37 classes, 256 hidden units)
     and load the pretrained weights configured under TRAIN.VAL.

     Returns the ready-to-use model placed on ``self.device``.
     """
     weights_path = self.config.TRAIN.VAL.crnn_pretrained
     print('loading pretrained crnn model from %s' % weights_path)
     recognizer = crnn.CRNN(32, 1, 37, 256).to(self.device)
     recognizer.load_state_dict(torch.load(weights_path))
     return recognizer
Ejemplo n.º 2
0
    def load(self, crnn_path):
        """Restore CRNN weights from *crnn_path* and switch to inference mode."""
        network = crnn.CRNN(self.IMGH, self.nc, self.nclass, nh=256)
        self.crnn = network.to(device)

        state = torch.load(crnn_path, map_location=device)
        self.crnn.load_state_dict(state)

        # Test mode: without this, layers such as dropout / batch-norm
        # would behave as in training.
        self.crnn.eval()
    plt.show()

if __name__ == '__main__':
    # Evaluation entry point: restore a trained CRNN checkpoint, run it on
    # the test set and plot its predictions against the 3D-CNN baseline.
    data_dir = 'data/3d/test/full_data'
    device = torch.device('cpu')

    # model hyper-parameters (must match the checkpoint being restored)
    channels = 10
    vector_dim = 500
    rnn_hidden_size = 500
    rnn_num_layers = 2
    lr = 0.00001

    #checkpoint_dir = 'gcloud/checkpoint_500vector'
    checkpoint_dir = 'gcloud/checkpoint_3months'
    restore_file = 'best'

    model = net.CRNN(channels, vector_dim, rnn_hidden_size, rnn_num_layers).to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    # baseline predictions produced by the plain 3D-CNN model
    predictions2 = np.load('3DCNN/3m_test_predictions.npy')
    #predictions2 = np.load('3DCNN/1m_test_predictions.npy')

    # ten years of monthly samples
    inputs, labels = create_dataset(12 * 10)
    load_model(checkpoint_dir, restore_file)
    predictions = evaluate(inputs, labels, model)
    plot_predictions(predictions.numpy(), labels.numpy(), predictions2)

    #plt.savefig('10 years validation 1m ahead')
Ejemplo n.º 4
0
    # print and save hyperparameters
    logging.info("learning rate: " + str(lr))
    logging.info("epochs: " + str(epochs))
    logging.info("batch size: " + str(batch_size))
    logging.info("transfer learning: " + str(transfer_learning))
    logging.info("Variables: " + str(variables))

    #Define the model, dataset
    if model_name == 'cnn':
        model = net.CNN(channels)
        data_dir = 'data/2d'
        logging.info("model: CNN")
    if model_name == 'crnn':
        model = net.CRNN(len(variables),
                         channels,
                         vector_dim,
                         rnn_hidden_size,
                         rnn_num_layers,
                         dropout=dropout)
        data_dir = 'data/3d_multivar'
        #data_dir = '/Volumes/matiascastilloHD/CLIMATEAI/3d_ensemble2_6m'
        logging.info("model: CRNN")
        logging.info("data: CNRM-MPI")
        logging.info("encoding dimesion: " + str(vector_dim))
        logging.info("RNN layers: " + str(rnn_num_layers))
        logging.info("RNN hidden units: " + str(rnn_hidden_size))

    # initialize model weights
    model.apply(net.initialize_weights)

    # reload weights from restore_file if specified
    if restore_file is not None:
def _eval_fold_metrics(counter, i, epoch, results_dir, model, train_dl, val_dl,
                       device, dtype):
    """Evaluate MSE and L1 on both splits of fold *i* at *epoch*.

    Returns (train_MSE, val_MSE, train_L1, val_L1, counter).  The split tag
    passed to ``evaluate`` now matches the dataloader actually evaluated —
    the original code tagged several validation runs as 'train'.
    """
    train_MSE, counter = evaluate(counter, 'train', i, epoch, results_dir,
                                  model, nn.MSELoss(), train_dl, device, dtype)
    val_MSE, counter = evaluate(counter, 'val', i, epoch, results_dir,
                                model, nn.MSELoss(), val_dl, device, dtype)
    train_L1, counter = evaluate(counter, 'train', i, epoch, results_dir,
                                 model, nn.L1Loss(), train_dl, device, dtype)
    val_L1, counter = evaluate(counter, 'val', i, epoch, results_dir,
                               model, nn.L1Loss(), val_dl, device, dtype)
    return train_MSE, val_MSE, train_L1, val_L1, counter


def _log_fold_metrics(train_MSE, val_MSE, train_L1, val_L1):
    """Log RMSE (sqrt of MSE) and L1 for one evaluation round."""
    logging.info("- Train average RMSE loss: " + str(np.sqrt(train_MSE)))
    logging.info("- Validation average RMSE loss: " + str(np.sqrt(val_MSE)))
    logging.info("- Train average L1 loss: " + str(train_L1))
    logging.info("- Validation average L1 loss: " + str(val_L1))


def main():
    """Leave-one-year-out cross-validation of the CRNN model.

    Loads the full dataset, then for every year (12 monthly samples) trains
    from re-initialized/pretrained weights while holding that year out, and
    saves train/validation MSE and L1 metrics per fold.

    Fixes relative to the original:
      * validation evaluations are tagged 'val' (were mislabelled 'train');
      * the StepLR scheduler is kept and stepped each epoch (was created
        and discarded, so the configured lr decay never ran);
      * CUDA seeding checks ``device.type`` (comparing a torch.device to
        the string "cuda:0" was always False);
      * the per-epoch metric logging no longer runs twice when both
        ``i < 1`` and ``epoch % 10 == 0`` hold.
    """
    data_dir = 'data/3d_real/all_6months'
    old_checkpoint_dir = 'experiments/6month_GCM_training/checkpoint_6m_drop0.1'
    results_dir = 'crossval_results_6m'
    labels_dir = 'labels6m.csv'

    all_inputs, all_labels = load_data(data_dir, labels_dir)
    print('Number of inputs: ' + str(len(all_inputs)))
    print('Number of labels: ' + str(len(all_labels)))

    # folds are whole years: inputs and labels must be multiples of 12
    if len(all_inputs) % 12 != 0 or len(all_labels) % 12 != 0:
        print('data must be complete years')
        return

    # model / training hyper-parameters
    channels = 10
    vector_dim = 500
    rnn_hidden_size = 500
    rnn_num_layers = 2
    dropout = 0.1
    lrate = 0.00001
    lr_step_size = 10
    lr_factor = 0.2
    epochs = 50
    restore_file = 'best'
    dtype = torch.float32

    # use GPU if available
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('Im using GPU')
    else:
        device = torch.device('cpu')

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    # BUG FIX: `device == "cuda:0"` compared a torch.device to a string and
    # was always False, so the CUDA RNG was never seeded.
    if device.type == 'cuda':
        torch.cuda.manual_seed(230)

    # create model and loss instance
    model = net.CRNN(channels,
                     vector_dim,
                     rnn_hidden_size,
                     rnn_num_layers,
                     dropout=dropout)
    loss_fn = net.loss_fn

    # Set the logger, will be saved in the results folder
    if not os.path.exists(results_dir):
        os.mkdir(results_dir)
    utils.set_logger(os.path.join(results_dir, 'train.log'))

    logging.info("learning rate: " + str(lrate))
    logging.info("learning rate step size: " + str(lr_step_size))
    logging.info("learning rate factor: " + str(lr_factor))
    logging.info("epochs: " + str(epochs))
    '''
    Start cross-validation
    '''

    counter = 0

    # one fold per held-out year
    for i in range(len(all_inputs) // 12):

        # create current dataloaders
        train_dl, val_dl = create_dataloaders(all_inputs, all_labels, i)

        # re-initialize model and optimizer for every new training set
        model.apply(net.initialize_weights)
        optimizer = optim.Adam(model.parameters(), lr=lrate,
                               betas=(0.9, 0.999))
        # BUG FIX: the scheduler return value was discarded, so the lr decay
        # configured (and logged) above never took effect.
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              lr_step_size,
                                              gamma=lr_factor,
                                              last_epoch=-1)

        # reload weights from restore_file if specified
        if restore_file is not None:
            load_parameters(old_checkpoint_dir, restore_file, model)

        logging.info('Start training for year: ' + str(i + 1))

        for epoch in range(epochs):

            if epoch == 0:
                # baseline metrics before any training on this fold
                train_MSE, counter = evaluate(counter, 'train', i,
                                              epoch, results_dir, model,
                                              nn.MSELoss(), train_dl, device,
                                              dtype)
                val_MSE, counter = evaluate(counter, 'val', i,
                                            epoch, results_dir, model,
                                            nn.MSELoss(), val_dl, device,
                                            dtype)
                logging.info("- Initial Train average RMSE loss: " +
                             str(np.sqrt(train_MSE)))
                logging.info("- Initial Validation average RMSE loss: " +
                             str(np.sqrt(val_MSE)))

            print("Epoch {}/{}".format(epoch + 1, epochs))
            train(model, optimizer, loss_fn, train_dl, device, dtype)
            # advance the lr decay schedule once per epoch
            scheduler.step()

            # Log metrics every epoch on the first fold, every 10 epochs
            # afterwards (single evaluation even when both conditions hold).
            if i < 1 or epoch % 10 == 0:
                logging.info("Epoch {}/{}".format(epoch + 1, epochs))
                train_MSE, val_MSE, train_L1, val_L1, counter = \
                    _eval_fold_metrics(counter, i, epoch + 1, results_dir,
                                       model, train_dl, val_dl, device, dtype)
                _log_fold_metrics(train_MSE, val_MSE, train_L1, val_L1)

        # Evaluate and save MSE and L1 at the end of each training
        train_MSE, val_MSE, train_L1, val_L1, counter = \
            _eval_fold_metrics(counter, i, epoch + 1, results_dir,
                               model, train_dl, val_dl, device, dtype)

        save_crossval(i, train_MSE, val_MSE, train_L1, val_L1, results_dir)
Ejemplo n.º 6
0
def run():
    """Train a CRNN license-plate recognizer with CTC loss.

    Parses hyper-parameters from the command line, builds grayscale train /
    validation loaders from a directory of plate images, constructs the CRNN
    (optionally restoring pre-trained weights for incremental learning), then
    trains with Adadelta, printing validation loss and saving a checkpoint
    roughly once per epoch.
    """
    #### argument parsing ####
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSize',
                        type=int,
                        default=64,
                        help='input batch size')
    parser.add_argument('--epoch',
                        type=int,
                        default=30,
                        help='training epochs')
    parser.add_argument('--dataPath',
                        required=True,
                        help='path to training dataset')
    parser.add_argument('--savePath',
                        required=True,
                        help='path to save trained weights')
    parser.add_argument(
        '--preTrainedPath',
        type=str,
        default=None,
        help='path to pre-trained weights (incremental learning)')
    parser.add_argument('--seed',
                        type=int,
                        default=8888,
                        help='reproduce experiement')
    parser.add_argument('--worker',
                        type=int,
                        default=0,
                        help='number of cores for data loading')
    # parser.add_argument('--imgW', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-1)
    parser.add_argument('--maxLength',
                        type=int,
                        default=9,
                        help='maximum license plate character length in data')
    opt = parser.parse_args()
    print(opt)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    #### set up constants and experiment settings ####
    IMGH = 32

    ## Feature: Char Resizing ##
    # 1. IMGW (const): Uncomment formula
    # 2. LPDataSet __getitem__ (class method): Uncomment PIL resize
    # 3. train_transformer (torchvision): Uncomment transform.Resize

    # Note:crnn output length = img_width / 4 + 1
    # Assumption: 4 cuts per character
    # Calculation: 4 * maxCharLen = img_width / 4 + 1
    #             => img_width = 16 * maxCharLen - 4
    # IMGW = opt.maxLength * 16 - 4
    IMGW = 100

    cudnn.benchmark = True

    # seed all RNGs for reproducibility
    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    if not os.path.exists(opt.savePath):
        os.makedirs(opt.savePath)

    #### data preparation & loading ####
    train_transformer = transforms.Compose([
        transforms.Grayscale(),
        #     transforms.RandomApply([Invert()], p=0.3), # taxi license plate
        transforms.Resize((IMGH, IMGW)),
        transforms.ToTensor()
    ])  # transform it into a torch tensor

    # 80/20 split over image indices in the data directory
    n = range(len(os.listdir(opt.dataPath)))
    train_idx, val_idx = train_test_split(n,
                                          train_size=0.8,
                                          test_size=0.2,
                                          random_state=opt.seed)

    # train data
    print("Checkpoint: Loading data")
    train_loader = DataLoader(LPDataset(opt.dataPath, train_idx,
                                        train_transformer),
                              batch_size=opt.batchSize,
                              num_workers=opt.worker,
                              shuffle=True,
                              pin_memory=True)
    print("Checkpoint: Data loaded")

    # validation data
    val_set = LPDataset(opt.dataPath, val_idx, train_transformer)

    #### setup crnn model hyperparameters ####
    # alphabet: A-Z + 0-9; +1 for the CTC blank label
    classes = string.ascii_uppercase + string.digits
    nclass = len(classes) + 1
    nc = 1  # number of channels 1=grayscale

    # CRNN(imgH, nc, nclass, num_hidden(LSTM))
    crnn = model.CRNN(IMGH, nc, nclass, 256).to(device)
    print("Checkpoint: Model loaded")

    # NOTE(review): range(1) wraps only device 0 even when several GPUs are
    # detected — confirm whether range(torch.cuda.device_count()) was intended.
    if torch.cuda.device_count() > 1:
        print("Running parallel on", torch.cuda.device_count(), "GPUs..")
        crnn = torch.nn.DataParallel(crnn, range(1))

    if opt.preTrainedPath is not None:
        crnn.load_state_dict(
            torch.load(opt.preTrainedPath, map_location=device))
    else:
        crnn.apply(weights_init)

    #### image and text (convert to tensor) ####
    # NOTE(review): the buffer is allocated IMGH x IMGH (square), not
    # IMGH x IMGW — presumably resized when batches are copied in; confirm.
    image = torch.FloatTensor(opt.batchSize, 1, IMGH, IMGH).to(device)
    text = torch.IntTensor(opt.batchSize * 5)
    length = torch.IntTensor(opt.batchSize)

    # Variable is a deprecated no-op wrapper in modern PyTorch (>= 0.4)
    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    #### decoder, loss function, batch loss ####
    converter = utils.strLabelConverter(classes)
    loss_avg = utils.averager()
    criterion = nn.CTCLoss().to(device)

    #### learning rate, lr scheduler, lr optimiser ####
    LR = opt.lr
    optimizer = optim.Adadelta(crnn.parameters(), lr=LR)
    # T_max = len(train_loader) * EPOCH
    # lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=LR/10)

    #### training begins ####

    # 25000 * 0.8 (# of data) // 64 (bs) ~= 310 (iterations)
    save_iter = len(os.listdir(opt.dataPath)) * 0.8 // opt.batchSize
    PRINT_ITER = save_iter

    print("Checkpoint: Start training")
    print("You are training on", device)
    for epoch in range(opt.epoch):
        train_iter = iter(train_loader)
        i = 0

        while i < len(train_loader):

            # ensure gradients flow (validation() may have frozen them)
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = trainBatch(crnn, criterion, optimizer, converter,
                              train_iter, image, text, length)
            loss_avg.add(cost)

            i += 1
            if i % PRINT_ITER == 0:
                # print training loss and validation loss
                print('[%d/%d][%d/%d] Train Loss: %f  Validation Loss: %f' %
                      (epoch + 1, opt.epoch, i + 1, len(train_loader),
                       loss_avg.val(),
                       validation(crnn, val_set, opt.batchSize, opt.worker,
                                  criterion, converter, image, text, length)))
                loss_avg.reset()

            if i % save_iter == 0:

                # unwrap DataParallel (has .module) before saving weights
                try:
                    state_dict = crnn.module.state_dict()
                except AttributeError:
                    state_dict = crnn.state_dict()

                torch.save(
                    state_dict,
                    os.path.join(opt.savePath,
                                 'netCRNN_{}_{}.pth'.format(epoch + 1, i + 1)))
Ejemplo n.º 7
0
def main(_):
    """Train a TF1 CRNN with CTC loss using a MonitoredTrainingSession.

    Builds the input queue, CRNN graph, CTC loss, beam-search decoder and
    edit-distance metric, then trains with Adadelta under hooks that log
    progress, count steps/sec and stop at --number_of_steps.  Checkpoints
    and summaries are written by the monitored session itself.
    """
    assert FLAGS.file_pattern, "--file_pattern is required"
    assert FLAGS.train_checkpoints, "--train_checkpoints is required"
    assert FLAGS.summaries_dir, "--summaries_dir is required"

    vocab = Vocabulary()

    model_config = configuration.ModelConfig()

    training_config = configuration.TrainingConfig()
    print(FLAGS.learning_rate)
    training_config.initial_learning_rate = FLAGS.learning_rate

    sequence_length = model_config.sequence_length
    batch_size = FLAGS.batch_size

    summaries_dir = FLAGS.summaries_dir
    if not tf.gfile.IsDirectory(summaries_dir):
        tf.logging.info("Creating training directory: %s", summaries_dir)
        tf.gfile.MakeDirs(summaries_dir)

    train_checkpoints = FLAGS.train_checkpoints
    if not tf.gfile.IsDirectory(train_checkpoints):
        tf.logging.info("Creating training directory: %s", train_checkpoints)
        tf.gfile.MakeDirs(train_checkpoints)

    # initialize the input data queue
    input_queue = DataReader(FLAGS.dataset_dir,
                             FLAGS.file_pattern,
                             model_config,
                             batch_size=batch_size)

    g = tf.Graph()
    with g.as_default():
        # input queue
        with tf.name_scope(None, 'input_queue'):
            input_images, input_labels = input_queue.read()

        # build the model
        model = crnn.CRNN(256, model_config.num_classes, 'train')
        logits = model.build(input_images)

        with tf.name_scope(None, 'loss'):
            # CTC loss; every batch element uses the same fixed sequence length
            loss = tf.reduce_mean(
                tf.nn.ctc_loss(labels=input_labels,
                               inputs=logits,
                               sequence_length=sequence_length *
                               tf.ones(batch_size, dtype=tf.int32)),
                name='compute_loss',
            )
            tf.losses.add_loss(loss)
            total_loss = tf.losses.get_total_loss(False)

        with tf.name_scope(None, 'decoder'):
            decoded, _ = tf.nn.ctc_beam_search_decoder(
                logits,
                sequence_length * tf.ones(batch_size, dtype=tf.int32),
                merge_repeated=False,
            )
            with tf.name_scope(None, 'acurracy'):
                # mean edit distance between decoded output and ground truth
                sequence_dist = tf.reduce_mean(
                    tf.edit_distance(tf.cast(decoded[0], tf.int32),
                                     input_labels),
                    name='seq_dist',
                )
            preds = tf.sparse_tensor_to_dense(decoded[0], name='prediction')
            gt_labels = tf.sparse_tensor_to_dense(input_labels,
                                                  name='Ground_Truth')

        # print(len(slim.get_model_variables()))
        # print('>>>>>>>>>>>>>>>>>>>>>>>>>>>')
        # print(len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
        # sys.exit()
        global_step = tf.Variable(initial_value=0,
                                  name="global_step",
                                  trainable=False,
                                  collections=[
                                      tf.GraphKeys.GLOBAL_STEP,
                                      tf.GraphKeys.GLOBAL_VARIABLES
                                  ])

        # staircase exponential lr decay driven by global_step
        start_learning_rate = training_config.initial_learning_rate
        learning_rate = tf.train.exponential_decay(
            start_learning_rate,
            global_step,
            decay_steps=training_config.learning_decay_steps,
            decay_rate=training_config.learning_rate_decay_factor,
            staircase=True,
        )

        # summary
        # Add summaries for variables.
        for variable in slim.get_model_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.scalar(name='Seq_Dist', tensor=sequence_dist)
        tf.summary.scalar(name='global_step', tensor=global_step)
        tf.summary.scalar(name='learning_rate', tensor=learning_rate)
        tf.summary.scalar(name='total_loss', tensor=total_loss)

        # global-steps/sec hook
        globalhook = tf.train.StepCounterHook(
            every_n_steps=FLAGS.log_every_n_steps, )
        # hook that saves checkpoints (superseded by MonitoredTrainingSession)
        # saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
        # saverhook = tf.train.CheckpointSaverHook(
        #     checkpoint_dir=FLAGS.train_checkpoints,
        #     save_steps=2000,
        #     saver=saver,
        # )
        # # hook that saves summaries
        # merge_summary_op = tf.summary.merge_all()
        # summaryhook = tf.train.SummarySaverHook(
        #     save_steps=200,
        #     output_dir=FLAGS.summaries_dir,
        #     summary_op=merge_summary_op,
        # )
        # logging hook used during training
        tensors_print = {
            'global_step': global_step,
            'loss': loss,
            'Seq_Dist': sequence_dist,
            # 'accurays':accurays,
        }
        loghook = tf.train.LoggingTensorHook(
            tensors=tensors_print,
            every_n_iter=FLAGS.log_every_n_steps,
        )
        # stop hook
        stophook = tf.train.StopAtStepHook(last_step=FLAGS.number_of_steps)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        session_config = tf.ConfigProto(log_device_placement=False,
                                        gpu_options=gpu_options)

        # extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # with tf.control_dependencies(extra_update_ops):
        #     optimizer = tf.train.AdadeltaOptimizer(
        #         learning_rate=learning_rate).minimize(loss=total_loss, global_step=global_step)

        optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate)
        train_op = tf.contrib.training.create_train_op(total_loss=total_loss,
                                                       optimizer=optimizer,
                                                       global_step=global_step)
        # train_op = tf.group([optimizer, total_loss, sequence_dist])
        # MonitoredTrainingSession handles checkpoint/summary saving itself
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_checkpoints,
                hooks=[globalhook, loghook, stophook],
                save_checkpoint_secs=180,
                save_summaries_steps=100,
                config=session_config) as sess:
            while not sess.should_stop():
                oloss, opreds, ogt_labels = sess.run(
                    [train_op, preds, gt_labels])
                accuray = compute_acuracy(opreds, ogt_labels)
                print("accuracy: %9f" % (accuray))
Ejemplo n.º 8
0
def main(_):
    """Sanity-check variant of the CRNN training script (no optimization).

    Builds the same input queue / CRNN / CTC loss / decoder graph, but the
    train op only increments global_step.  Each iteration fetches logits,
    images and label shapes and prints a marker when any of them does not
    match --batch_size — i.e. it exercises the data pipeline for incomplete
    batches rather than training the model.
    """
    assert FLAGS.file_pattern, "--file_pattern is required"
    assert FLAGS.train_checkpoints, "--train_checkpoints is required"
    assert FLAGS.summaries_dir, "--summaries_dir is required"

    vocab = Vocabulary()

    model_config = configuration.ModelConfig()

    training_config = configuration.TrainingConfig()
    print(FLAGS.learning_rate)
    training_config.initial_learning_rate = FLAGS.learning_rate

    sequence_length = model_config.sequence_length
    batch_size = FLAGS.batch_size

    summaries_dir = FLAGS.summaries_dir
    if not tf.gfile.IsDirectory(summaries_dir):
        tf.logging.info("Creating training directory: %s", summaries_dir)
        tf.gfile.MakeDirs(summaries_dir)

    train_checkpoints = FLAGS.train_checkpoints
    if not tf.gfile.IsDirectory(train_checkpoints):
        tf.logging.info("Creating training directory: %s", train_checkpoints)
        tf.gfile.MakeDirs(train_checkpoints)

    # initialize the input data queue
    input_queue = DataReader(FLAGS.dataset_dir,
                             FLAGS.file_pattern,
                             model_config,
                             batch_size=batch_size)

    g = tf.Graph()
    with g.as_default():
        # input queue
        with tf.name_scope(None, 'input_queue'):
            input_images, input_labels = input_queue.read()

        # build the model
        model = crnn.CRNN(256, model_config.num_classes, 'train')
        logits = model.build(input_images)

        with tf.name_scope(None, 'loss'):

            loss = tf.reduce_mean(
                tf.nn.ctc_loss(labels=input_labels,
                               inputs=logits,
                               sequence_length=sequence_length *
                               tf.ones(batch_size, dtype=tf.int32)),
                name='compute_loss',
            )
            tf.losses.add_loss(loss)
            total_loss = tf.losses.get_total_loss(False)

        with tf.name_scope(None, 'decoder'):
            decoded, _ = tf.nn.ctc_beam_search_decoder(
                logits,
                sequence_length * tf.ones(batch_size, dtype=tf.int32),
                merge_repeated=False,
            )
            with tf.name_scope(None, 'acurracy'):
                sequence_dist = tf.reduce_mean(
                    tf.edit_distance(tf.cast(decoded[0], tf.int32),
                                     input_labels),
                    name='seq_dist',
                )
            preds = tf.sparse_tensor_to_dense(decoded[0], name='prediction')
            gt_labels = tf.sparse_tensor_to_dense(input_labels,
                                                  name='Ground_Truth')

        global_step = tf.Variable(initial_value=0,
                                  name="global_step",
                                  trainable=False,
                                  collections=[
                                      tf.GraphKeys.GLOBAL_STEP,
                                      tf.GraphKeys.GLOBAL_VARIABLES
                                  ])

        # logging hook used during training
        tensors_print = {
            'global_step': global_step,
            #'loss': loss,
        }
        loghook = tf.train.LoggingTensorHook(
            tensors=tensors_print,
            every_n_iter=FLAGS.log_every_n_steps,
        )
        # stop hook
        stophook = tf.train.StopAtStepHook(last_step=FLAGS.number_of_steps)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        session_config = tf.ConfigProto(log_device_placement=False,
                                        gpu_options=gpu_options)

        # "training" here is just a step counter — no optimizer is applied
        train_op = tf.assign_add(global_step, tf.constant(1))
        session = tf.train.ChiefSessionCreator(
            config=session_config,
            checkpoint_dir=FLAGS.train_checkpoints,
        )

        labels_shape = input_labels.dense_shape
        with tf.train.MonitoredSession(session, hooks=[loghook,
                                                       stophook]) as sess:

            while not sess.should_stop():
                # fetch one batch and flag any tensor whose batch dim
                # differs from --batch_size (pipeline debugging aid)
                test_logits, test_images, test_shape, _ = \
                        sess.run([logits, input_images, labels_shape, input_labels])
                if test_logits.shape[
                        1] != FLAGS.batch_size or test_images.shape[
                            0] != FLAGS.batch_size or test_shape[
                                0] != FLAGS.batch_size:
                    print("get it!!!!!")
                test_loss = sess.run([loss])
                sess.run(train_op)
Ejemplo n.º 9
0

# Inference setup: interactive TF1 session, single-image placeholder,
# CRNN forward pass and a CTC beam-search decoder producing dense labels.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
sess = tf.InteractiveSession(config=config)


# build the model
vocab = Vocabulary()

with tf.name_scope(None, 'input_image'):
    # single 32x300 RGB image, batched to rank 4 for the network
    img_input = tf.placeholder(tf.uint8, shape=(32, 300, 3))
    image = tf.to_float(img_input)
    image = tf.expand_dims(image, 0)

model = crnn.CRNN(256, 37, 'inference')
logit = model.build(image)

# print(logit.get_shape().as_list())
# print(tf.shape(logit)[0])
# sys.exit()

# NOTE(review): tf.shape(logit)[0] * np.ones(1) produces a float64 value,
# while ctc_beam_search_decoder expects int32 sequence lengths — confirm
# this runs without an explicit cast.
decodes, _ = tf.nn.ctc_beam_search_decoder(inputs=logit,
                                           sequence_length=tf.shape(
                                               logit)[0]*np.ones(1),
                                           merge_repeated=False
                                           )
pred = tf.sparse_tensor_to_dense(decodes[0])


# restore the model