예제 #1
0
    def run_eval(self, epoch, batch_size=16):
        """Evaluate the model on the validation split and record the accuracy.

        A fresh validation ``DataGen`` is built and iterated batch by batch.
        Iteration stops as soon as the cumulative requested sample count
        exceeds ``get_dataset_size()``, so a trailing partial batch is not
        evaluated. The accuracy is printed as a percentage, a full
        classification report is printed every 10th epoch, and the raw
        accuracy fraction is appended to ``checkpoints/<out_name>/val.txt``.

        Args:
            epoch: Zero-based epoch index; used only for logging.
            batch_size: Number of samples requested per generator batch.
        """
        # NOTE(review): the positional ``5`` is forwarded unchanged to DataGen;
        # presumably the number of classes — confirm against DataGen.
        val_dataset = DataGen(self.IMG_SIZE, 5, False)
        val_gen = val_dataset.generator(batch_size)
        dataset_size = val_dataset.get_dataset_size()

        count = 0
        pred_batches = []
        gt_batches = []
        for X_batch, y_batch in val_gen:
            count += batch_size
            if count > dataset_size:
                break

            y_pred = self.model.predict(X_batch)
            # Reduce scores / one-hot targets to class indices on the last axis.
            pred_batches.append(np.argmax(y_pred, -1))
            gt_batches.append(np.argmax(y_batch, -1))

        # Concatenate once instead of re-copying the growing array every batch.
        Y_pred = np.concatenate(pred_batches) if pred_batches else np.array([])
        Y_gt = np.concatenate(gt_batches) if gt_batches else np.array([])

        acc = accuracy_score(Y_gt, Y_pred)
        # accuracy_score returns a fraction in [0, 1]; scale it for the '%' label.
        print('Eval Accuracy: %2.2f%%' % (acc * 100), '@ Epoch ', epoch)
        if (epoch + 1) % 10 == 0:
            print(classification_report(Y_gt, Y_pred))

        with open('checkpoints/' + self.out_name + '/val.txt', 'a+') as xfile:
            xfile.write('Epoch ' + str(epoch) + ':' + str(acc) + '\n')
예제 #2
0
파일: model.py 프로젝트: bywmm/Retinopathy
    def train(self, batch_size, epoches, out_name, mini_data):
        """Compile the model and fit it on the training ``DataGen``.

        The model is compiled with Adam and categorical cross-entropy and
        trained with four callbacks, in this order: ``EvalCallBack(out_name)``,
        ``EarlyStopping`` on ``val_loss`` (patience 6), TensorBoard logging to
        ``logs/<timestamp>``, and a step-decay learning-rate schedule.

        Args:
            batch_size: Samples per generator batch.
            epoches: Total number of epochs for ``fit_generator``.
            out_name: Run name forwarded to ``EvalCallBack``.
            mini_data: Forwarded to ``DataGen`` as ``mini``.
        """
        print(mini_data)

        def step_decay(epoch):
            # Start at 1e-3 and halve once per completed block of 7 epochs,
            # counted on (1 + epoch).
            base_rate, factor, period = 1e-3, 0.5, 7.0
            exponent = math.floor((1 + epoch) / period)
            return base_rate * math.pow(factor, exponent)

        dataset = DataGen(self.IMG_SIZE, 5, True, mini=mini_data)
        batches = dataset.generator(batch_size, True)

        run_stamp = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now())

        # NOTE(review): fit_generator receives no validation data here, so
        # EarlyStopping only has 'val_loss' to watch if EvalCallBack puts it
        # into the epoch logs — confirm.
        eval_cb = EvalCallBack(out_name)
        stop_cb = EarlyStopping(monitor='val_loss', mode='min', patience=6)
        board_cb = TensorBoard(log_dir='logs/' + run_stamp,
                               batch_size=batch_size,
                               update_freq='epoch')
        lr_cb = LearningRateScheduler(step_decay)

        self.model.compile(optimizer='adam',
                           loss='categorical_crossentropy',
                           metrics=['categorical_accuracy'])

        steps = dataset.get_dataset_size() // batch_size
        self.model.fit_generator(generator=batches,
                                 steps_per_epoch=steps,
                                 epochs=epoches,
                                 callbacks=[eval_cb, stop_cb, board_cb, lr_cb])
예제 #3
0
    def train(self, train_data, valid_data):
        """Train a YOLOv3 model for ``flags.max_epochs`` epochs.

        If ``BASE_MODEL_URL`` names an existing path, the newest checkpoint in
        it (if any) is restored before training. Each epoch runs
        ``round(train_data_size / batch_size)`` steps; batches for which
        ``next_batch_train()`` returns a falsy value are skipped. After every
        epoch a checkpoint (without meta graph) is written to
        ``model.model_dir``.

        Args:
            train_data: Training data forwarded to ``DataGen``.
            valid_data: Validation data forwarded to ``DataGen``; no
                validation pass is currently run.
        """
        yolo_config = YoloConfig()

        data_gen = DataGen(yolo_config, train_data, valid_data)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            model = Yolo3(sess, True, yolo_config)

            if BASE_MODEL_URL and os.path.exists(BASE_MODEL_URL):
                LOG.info(
                    f"loading base model, BASE_MODEL_URL={BASE_MODEL_URL}")
                latest_ckpt = tf.train.latest_checkpoint(BASE_MODEL_URL)
                LOG.info(f"latest_ckpt={latest_ckpt}")
                # latest_checkpoint returns None for a directory without
                # checkpoints; restoring None would fail, so start fresh.
                if latest_ckpt:
                    saver = tf.train.Saver()
                    saver.restore(sess, latest_ckpt)
                else:
                    LOG.info(f"no checkpoint found in {BASE_MODEL_URL}; "
                             "training from scratch")

            steps_per_epoch = int(
                round(data_gen.train_data_size / data_gen.batch_size))
            total = steps_per_epoch * flags.max_epochs
            with tqdm(desc='Train: ', total=total) as pbar:
                for epoch in range(flags.max_epochs):
                    LOG.info('Epoch %d...' % epoch)
                    for _ in range(steps_per_epoch):
                        # Get a batch from the queue and make a step.
                        batch_data = data_gen.next_batch_train()
                        if not batch_data:
                            continue

                        batch_loss = model.step(sess, batch_data, True)
                        pbar.set_description(
                            'Train, input_shape=(%d, %d), loss=%.4f' %
                            (batch_data['input_shape'][0],
                             batch_data['input_shape'][1], batch_loss))
                        pbar.update()

                    LOG.info("Saving model, global_step: %d" %
                             model.global_step.eval())
                    checkpoint_path = os.path.join(
                        model.model_dir, "yolo3-epoch%03d.ckpt" % (epoch))
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=model.global_step,
                                     write_meta_graph=False)
예제 #4
0
    def train(self, train_data, valid_data=None, **kwargs):
        """Train YOLOv3 on ``train_data.x`` and report the mean batch loss.

        If ``model.model_dir`` exists, the newest checkpoint in it (if any) is
        restored; otherwise the directory is created. Each epoch runs
        ``round(train_data_size / batch_size)`` steps, skipping falsy batches,
        and ends with a checkpoint save (no meta graph).

        Args:
            train_data: Object whose ``.x`` attribute is handed to ``DataGen``.
            valid_data: Accepted for interface compatibility; unused.
            **kwargs: ``epochs`` overrides ``flags.max_epochs``.

        Returns:
            ``{"loss": float(np.mean(<all recorded batch losses>))}``.
        """
        cfg = YoloConfig()
        gen = DataGen(cfg, train_data.x)
        n_epochs = int(kwargs.get("epochs", flags.max_epochs))

        sess_cfg = tf.ConfigProto(allow_soft_placement=True)
        sess_cfg.gpu_options.allow_growth = True

        with tf.Session(config=sess_cfg) as sess:
            model = Yolo3(sess, True, cfg)
            ckpt_dir = model.model_dir

            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            else:
                saver = tf.train.Saver()
                newest = tf.train.latest_checkpoint(ckpt_dir)
                if newest:
                    LOG.info(f"latest_ckpt={newest}")
                    saver.restore(sess, newest)

            steps = int(round(gen.train_data_size / gen.batch_size))
            losses = []
            with tqdm(desc='Train: ', total=steps * n_epochs) as pbar:
                for epoch in range(n_epochs):
                    LOG.info('Epoch %d...' % epoch)
                    for _ in range(steps):
                        batch = gen.next_batch_train()
                        if not batch:
                            continue

                        step_loss = model.step(sess, batch, True)
                        shape = batch['input_shape']
                        pbar.set_description(
                            'Train, input_shape=(%d, %d), loss=%.4f' %
                            (shape[0], shape[1], step_loss))
                        pbar.update()
                        losses.append(step_loss)

                    LOG.info("Saving model, global_step: %d" %
                             model.global_step.eval())
                    model.saver.save(
                        sess,
                        os.path.join(ckpt_dir, "yolo3-epoch%03d.ckpt" % epoch),
                        global_step=model.global_step,
                        write_meta_graph=False)
            return {"loss": float(np.mean(losses))}
예제 #5
0
    def build_datagen(self, filepath, with_aug=True):
        """Build a ``DataGen`` over ``filepath`` from the loaded config.

        Reads ``config['train']['batch_size']`` and the model input
        width/height, then passes them with ``self.preprocess_input`` and
        ``with_aug`` to ``DataGen``.
        """
        batch_size = self.config['train']['batch_size']
        model_cfg = self.config['model']
        image_size = (model_cfg['input_width'], model_cfg['input_height'])
        return DataGen(filepath, batch_size, image_size,
                       self.preprocess_input, with_aug)
예제 #6
0
파일: model.py 프로젝트: bywmm/Retinopathy
    def eval(self, batch_size, out_name, mini_data=True):
        """Evaluate on the validation split and save its confusion matrix.

        Batches are drawn from a validation ``DataGen`` until the cumulative
        requested sample count exceeds ``get_dataset_size()`` (a trailing
        partial batch is therefore not evaluated). The accuracy is printed as
        a percentage and the confusion matrix is saved to
        ``confusion_<out_name>.npy``.

        Args:
            batch_size: Samples requested per generator batch.
            out_name: Suffix for the saved confusion-matrix file.
            mini_data: Forwarded positionally to ``DataGen``.
        """
        val_dataset = DataGen(self.IMG_SIZE, 5, False, mini_data)
        dataset_size = val_dataset.get_dataset_size()

        print("val data size: ", dataset_size)

        pred_batches = []
        gt_batches = []
        count = 0
        for X_batch, y_batch in val_dataset.generator(batch_size):
            count += batch_size
            print("count:", count)

            if count > dataset_size:
                break

            y_pred = self.model.predict(X_batch)
            # Reduce scores / one-hot targets to class indices on the last axis.
            pred_batches.append(np.argmax(y_pred, -1))
            gt_batches.append(np.argmax(y_batch, -1))

        # Concatenate once instead of re-copying the growing array every batch.
        Y_pred = np.concatenate(pred_batches) if pred_batches else np.array([])
        Y_gt = np.concatenate(gt_batches) if gt_batches else np.array([])

        acc = accuracy_score(Y_gt, Y_pred)
        # accuracy_score returns a fraction in [0, 1]; scale it for the '%' label.
        print('Eval Accuracy: %2.2f%%' % (acc * 100))
        np_confusion = confusion_matrix(Y_gt, Y_pred)
        np.save('confusion_' + str(out_name) + '.npy', np_confusion)
예제 #7
0
 def build_datagen(self, filepath, with_aug=True):
     """Build a ``DataGen`` over ``filepath`` from the loaded config.

     Pulls batch size and data directory from ``config['train']`` and input
     size / class count from ``config['model']``; the remaining positional
     arguments are passed to ``DataGen`` unchanged.
     """
     train_cfg = self.config['train']
     batch_size = train_cfg['batch_size']
     data_dir = train_cfg['data_dir']

     model_cfg = self.config['model']
     image_size = (model_cfg['input_width'], model_cfg['input_height'])
     class_num = model_cfg['class_num']

     # NOTE(review): the literal 3, True, 'train' and None are positional
     # DataGen arguments whose meaning is defined by DataGen — confirm there.
     return DataGen(filepath, batch_size, class_num, image_size, 3,
                    self.preprocess_input, with_aug, True, data_dir,
                    'train', None)
예제 #8
0
def main():
    """Iterate ``DataGen.gen`` at doubling batch sizes and print progress.

    Runs eight rounds with batch sizes 2, 4, ..., 256. In each round the
    reporting interval is ``2000 // batch_size`` (1000, 500, ..., 7), and every
    batch index divisible by it is printed.
    """
    data_root_dir = r'D:\myData\huawei_datetext\train_img'
    data_path = r'D:\myData\huawei_datetext\train_txt.txt'

    gens = DataGen(data_root_dir,
                   data_path,
                   lexicon_file='date_lexicon.txt',
                   mean=[128],
                   channel=1,
                   evaluate=False,
                   valid_target_len=float('inf'))

    for round_idx in range(1, 9):
        batch_size = 2 ** round_idx
        # Equivalent to halving 2000 with floor division once per round.
        interval = 2000 // batch_size
        print('batch_size = ', batch_size)
        # NOTE(review): if gens.gen() never terminates, later rounds never run.
        for idx, _ in enumerate(gens.gen(batch_size)):
            if idx % interval == 0:
                print("get batch index : " + str(idx))
예제 #9
0
파일: model.py 프로젝트: bywmm/Retinopathy
    def resume_train(self,
                     batch_size,
                     model_json,
                     model_weights,
                     init_epoch,
                     epochs,
                     out_name,
                     mini_data=True):
        """Reload a saved model and continue training from ``init_epoch``.

        The architecture/weights are loaded, the model is recompiled with
        Adam(5e-4), and training resumes with ``EvalCallBack``, early stopping
        on ``val_loss`` and TensorBoard logging to ``logs/<timestamp>``.

        Args:
            batch_size: Samples per generator batch.
            model_json: Path to the saved model architecture.
            model_weights: Path to the saved weights.
            init_epoch: Epoch index to resume from.
            epochs: Final epoch index (exclusive) for ``fit_generator``.
            out_name: Run name forwarded to ``EvalCallBack``.
            mini_data: Forwarded to ``DataGen`` as ``mini``.
        """
        self.load_model(model_json, model_weights)
        self.model.compile(optimizer=Adam(lr=5e-4),
                           loss='categorical_crossentropy',
                           metrics=["categorical_accuracy"])

        dataset = DataGen(self.IMG_SIZE, 5, is_train=True, mini=mini_data)
        batches = dataset.generator(batch_size, True)

        print(os.path.dirname(os.path.abspath(model_json)), model_json)

        run_stamp = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now())
        # NOTE(review): no validation data is passed below, so EarlyStopping
        # only sees 'val_loss' if EvalCallBack writes it into the logs.
        callbacks = [
            EvalCallBack(out_name),
            EarlyStopping(monitor='val_loss', mode='min', patience=6),
            TensorBoard(log_dir='logs/' + run_stamp,
                        batch_size=batch_size,
                        update_freq='epoch'),
        ]

        steps = dataset.get_dataset_size() // batch_size
        self.model.fit_generator(generator=batches,
                                 steps_per_epoch=steps,
                                 initial_epoch=init_epoch,
                                 epochs=epochs,
                                 callbacks=callbacks)
예제 #10
0
                          write_grads=False,
                          write_images=False,
                          embeddings_freq=0,
                          embeddings_layer_names=None,
                          embeddings_metadata=None,
                          embeddings_data=None)
np.random.seed(1996)

# Shuffle sample indices reproducibly and hold out the last 20% for validation.
indexes = np.arange(train_dataset_info.shape[0])
np.random.shuffle(indexes)

split_point = int(len(indexes) * 0.8)
train_indexes = indexes[:split_point]
valid_indexes = indexes[split_point:]

train_datagen = DataGen.create_set(train_dataset_info[train_indexes],
                                   BATCH_SIZE,
                                   SHAPE,
                                   augment=True)
valid_datagen = DataGen.create_set(train_dataset_info[valid_indexes],
                                   BATCH_SIZE,
                                   SHAPE,
                                   augment=False)

# Feed the validation generator itself rather than a single next() batch, so
# val_loss is measured over the held-out split each epoch instead of on the
# same BATCH_SIZE samples every time.
# NOTE(review): assumes create_set yields batches indefinitely, as Keras
# requires for generator validation data — confirm.
validation_steps = max(1, int(np.ceil(len(valid_indexes) / BATCH_SIZE)))

history = model.fit_generator(train_datagen,
                              validation_data=valid_datagen,
                              validation_steps=validation_steps,
                              steps_per_epoch=STEPS_PER_EPOCH,
                              epochs=EPOCHS,
                              verbose=1,
                              callbacks=[checkpointer, tensorboard])

submit = pd.read_csv('./res/sample_submission.csv')
예제 #11
0
    def test(self, data_path):
        """Evaluate the model one image at a time on ``data_path``.

        Each sample is scored by exact match, or, when ``self.use_distance``
        is set, by ``1 - min(1, levenshtein(output, ground) / len(ground))``.
        An empty label scores 1 only if the prediction is also empty.
        Running accuracy, loss, perplexity and probability are logged per
        step, and attention maps are rendered when ``self.visualize`` is set.

        Args:
            data_path: Dataset location handed to ``DataGen`` (one epoch).

        Returns:
            Mean per-sample accuracy in ``[0, 1]``. If the generator yields
            no samples, a warning is logged and ``0.0`` is returned instead of
            raising ``ZeroDivisionError``.
        """
        current_step = 0
        num_correct = 0.0
        num_total = 0.0

        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=1,
                        max_width=self.max_original_width)
        for batch in s_gen.gen(1):
            current_step += 1
            # Get a batch (one image) and make a step, timing it.
            start_time = time.time()
            result = self.step(batch, self.forward_only, 0.0)
            curr_step_time = (time.time() - start_time)

            num_total += 1

            output = result['prediction']
            ground = batch['labels'][0]
            comment = batch['comments'][0]
            if sys.version_info >= (3, ):
                # These values are decoded from bytes before comparison/logging.
                output = output.decode('iso-8859-1')
                ground = ground.decode('iso-8859-1')
                comment = comment.decode('iso-8859-1')

            probability = result['probability']

            if self.use_distance:
                if not ground:
                    # Cannot normalise by an empty label: score it binary.
                    incorrect = 0 if not output else 1
                else:
                    incorrect = (float(distance.levenshtein(output, ground)) /
                                 len(ground))
                incorrect = min(1, incorrect)
            else:
                incorrect = 0 if output == ground else 1

            num_correct += 1. - incorrect

            if self.visualize:
                # Attention visualization.
                threshold = 0.5
                normalize = True
                binarize = True
                attns_list = [[a.tolist() for a in step_attn]
                              for step_attn in result['attentions']]
                attns = np.array(attns_list).transpose([1, 0, 2])
                visualize_attention(batch['data'],
                                    'out',
                                    attns,
                                    output,
                                    self.max_width,
                                    DataGen.IMAGE_HEIGHT,
                                    threshold=threshold,
                                    normalize=normalize,
                                    binarize=binarize,
                                    ground=ground,
                                    flag=None)

            step_accuracy = "{:>4.0%}".format(1. - incorrect)
            if incorrect:
                correctness = step_accuracy + " ({} vs {}) {}".format(
                    output, ground, comment)
            else:
                correctness = step_accuracy + " (" + ground + ")"

            logging.info(
                'Step {:.0f} ({:.3f}s). '
                'Accuracy: {:6.2%}, '
                'loss: {:f}, perplexity: {:0<7.6}, probability: {:6.2%} {}'.
                format(
                    current_step, curr_step_time, num_correct / num_total,
                    result['loss'],
                    # Cap the exponent to avoid math.exp overflow.
                    math.exp(result['loss']) if result['loss'] < 300 else
                    float('inf'), probability, correctness))

        if not num_total:
            logging.warning('test(): no samples were read from %s', data_path)
            return 0.0
        return num_correct / num_total
예제 #12
0
model.add(
    Convolution2D(64,
                  3,
                  3,
                  subsample=(1, 1),
                  activation='relu',
                  border_mode='same'))
model.add(Flatten())
model.add(Dropout(args.dense_drop))
# NOTE(review): these Dense layers pass no activation (Keras default is
# linear), so the stack 100 -> 50 -> 10 -> 1 is a single affine map — confirm
# this is intended.
model.add(Dense(100))
model.add(Dropout(args.dense_drop))
model.add(Dense(50))
model.add(Dense(10))
model.add(Dense(1))

gen = DataGen()

# Single source of truth for the epoch count used by training.
nb_epoch = 5

model.compile(optimizer=Adam(lr=.0001), loss='mse')
history = model.fit_generator(gen.next_train(),
                              samples_per_epoch=gen.samples_per_epoch(),
                              nb_epoch=nb_epoch,
                              validation_data=gen.next_valid(),
                              nb_val_samples=gen._validation.shape[0])

val_loss = history.history['val_loss']
plt.plot(val_loss)
plt.title('dropout')
plt.ylabel('validation loss')
plt.xlabel('epoch')
# One tick per recorded epoch, so the axis tracks nb_epoch automatically.
plt.xticks(list(range(len(val_loss))))
plt.savefig("dropout.jpg")
예제 #13
0
파일: train.py 프로젝트: jfyang-cn/tf-learn
def main(args):
    """Train ``autoencoder`` on the train/val ``DataGen`` splits.

    Weight checkpoints are written every 5 epochs to ``ckpt-<dataset name>``,
    TensorBoard logs go to ``logs``, and the final model is saved to
    ``./model.h5``. When both ``args.weights`` and ``args.start_epoch`` are
    given, training resumes from those weights at that epoch.

    Args:
        args: Namespace with ``dataset``, ``traintxt``, ``valtxt``,
            ``weights`` and ``start_epoch`` attributes.
    """
    dataset_root = args.dataset
    train_gen = DataGen(filepath=args.traintxt, path_dataset=dataset_root)
    val_gen = DataGen(filepath=args.valtxt, path_dataset=dataset_root)

    ckpt_dir = 'ckpt-' + train_gen.name()
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    ckpt_pattern = os.path.join(
        ckpt_dir, 'weights-%s-{epoch:02d}-{loss:.2f}.hdf5' % (stamp))
    # NOTE(review): with save_best_only=False, Keras does not use `monitor`
    # to decide whether to save — every 5th epoch is written regardless.
    checkpoint = ModelCheckpoint(filepath=ckpt_pattern,
                                 monitor='val_loss',
                                 verbose=1,
                                 save_best_only=False,
                                 save_weights_only=True,
                                 period=5)
    tensorboard = TensorBoard(log_dir='logs', histogram_freq=0)

    weights_dir = 'weights'
    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir)

    # NOTE(review): the strategy is created and its replica count printed,
    # but the model is not built inside strategy.scope(), so it is unused.
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))

    model = autoencoder()
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                  loss=sparse_crossentropy)
    model.summary()

    start_epoch = 0
    # Optional resume: both the weights file and the epoch must be supplied.
    if args.weights is not None and args.start_epoch is not None:
        start_epoch = int(args.start_epoch)
        model.load_weights(args.weights)

    model.fit_generator(generator=train_gen,
                        steps_per_epoch=len(train_gen),
                        validation_data=val_gen,
                        validation_steps=len(val_gen),
                        initial_epoch=start_epoch,
                        epochs=1000,
                        callbacks=[checkpoint, tensorboard],
                        use_multiprocessing=False,
                        verbose=1,
                        workers=1,
                        max_queue_size=10)

    model.save('./model.h5')
예제 #14
0
    def test(self, data_path):
        """Evaluate the model one image at a time, appending results to log.txt.

        Each sample is scored by exact match, or, when ``self.use_distance``
        is set, by a Levenshtein distance normalised by the label length and
        capped at 1. Per-step accuracy, loss, perplexity and probability are
        appended to ``log.txt``. When ``self.visualize`` is set, the first
        attention map is passed to ``self.visualize_attention``.

        Args:
            data_path: Dataset location handed to ``DataGen`` (one epoch).
        """
        current_step = 0
        num_correct = 0.0
        num_total = 0.0

        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=1,
                        max_width=self.max_original_width)
        for batch in s_gen.gen(1):
            current_step += 1
            # Get a batch (one image) and make a step, timing it.
            start_time = time.time()
            result = self.step(batch, self.forward_only)
            curr_step_time = (time.time() - start_time)

            if self.visualize:
                step_attns = np.array([[a.tolist() for a in step_attn]
                                       for step_attn in result['attentions']
                                       ]).transpose([1, 0, 2])

            num_total += 1

            output = result['prediction']
            ground = batch['labels'][0]
            comment = batch['comments'][0]
            if sys.version_info >= (3, ):
                # These values are decoded from bytes before comparison/logging.
                output = output.decode('iso-8859-1')
                ground = ground.decode('iso-8859-1')
                comment = comment.decode('iso-8859-1')

            probability = result['probability']

            if self.use_distance:
                incorrect = distance.levenshtein(output, ground)
                if len(ground) == 0:
                    # Cannot normalise by an empty label: score it binary.
                    if len(output) == 0:
                        incorrect = 0
                    else:
                        incorrect = 1
                else:
                    incorrect = float(incorrect) / len(ground)
                incorrect = min(1, incorrect)
            else:
                incorrect = 0 if output == ground else 1

            num_correct += 1. - incorrect

            if self.visualize:
                self.visualize_attention(batch['data'], step_attns[0], output,
                                         ground, incorrect)

            step_accuracy = "{:>4.0%}".format(1. - incorrect)
            correctness = step_accuracy + (" ({} vs {}) {}".format(
                output, ground, comment) if incorrect else " (" + ground + ")")

            message = (
                '\nStep {:.0f} ({:.3f}s). Accuracy: {:6.2%}, loss: {:f}, perplexity: {:0<7.6}, probability: {:6.2%} {}'
                .format(
                    current_step, curr_step_time, num_correct / num_total,
                    result['loss'],
                    # Cap the exponent to avoid math.exp overflow.
                    math.exp(result['loss']) if result['loss'] < 300 else
                    float('inf'), probability, correctness))
            # Append per step so progress survives a crash; the context
            # manager closes the handle even if the write raises.
            with open('log.txt', 'a') as log:
                log.write(message)
import matplotlib.pyplot as plt
from data_gen import DataGen
import numpy as np
from sklearn.model_selection import train_test_split

# Alternative database location: r'G:\mix5k.h5'
db_path = r'E:\images_uint8.h5'

# Train/test split ratio and patch-extraction / batching settings.
train_split = 0.75
ishape = (16, 16)
pshape = (64, 64)
strides = (32, 32)
bsize = 64
num_epochs = 30

dg = DataGen(db_path, pshape=pshape, strides=strides, ds_method='cubic')
with dg.load_db() as db:
    all_data = dg.get_datasets()
    train, test = train_test_split(all_data,
                                   train_size=train_split,
                                   shuffle=True,
                                   random_state=7)
    train_patches = dg.get_patch_list(train)
    test_patches = dg.get_patch_list(test, pshape=pshape, strides=strides)

    print('Loading Images ....')
    dg.load_images()

    # NOTE(review): test_gen and val_gen are two independent generators over
    # the same test patches.
    train_gen, test_gen, val_gen = (
        dg.patch_gen(patches, bsize)
        for patches in (train_patches, test_patches, test_patches))

    # Number of full batches per pass.
    len_train, len_test = (len(train_patches) // bsize,
                           len(test_patches) // bsize)
예제 #16
0
def generate_exp_data():
    """Run ``DataGen.generate_data`` for every (dataset, frame count) pair.

    Iterates the module-level ``datasets`` and ``frames`` sequences; each
    dataset name gets a trailing ``/`` and each frame count is passed as
    ``fpv``.
    """
    for dataset in datasets:
        dataset_dir = f"{dataset}/"
        for fpv in frames:
            DataGen(dataset_dir, fpv=fpv).generate_data()
예제 #17
0
    def train(self, data_path, num_epoch):
        """Train for ``num_epoch`` epochs, appending progress to ``log.txt``.

        Every step's loss and perplexity are appended to ``log.txt`` and its
        summaries are written for TensorBoard. Every
        ``self.steps_per_checkpoint`` steps, the window-averaged time and loss
        are logged and a checkpoint is saved. A final summary and checkpoint
        are written once the generator is exhausted.

        Args:
            data_path: Dataset location handed to ``DataGen``.
            num_epoch: Number of passes over the data.
        """
        logging.info('num_epoch: %d' % num_epoch)
        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=num_epoch,
                        max_width=self.max_original_width)
        step_time = 0.0
        loss = 0.0
        current_step = 0
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)

        # Every record starts with '\n' so consecutive writes (and runs)
        # never share a line.
        with open('log.txt', 'a') as log:
            log.write('\nStarting the training process.')

        for batch in s_gen.gen(self.batch_size):
            current_step += 1

            start_time = time.time()
            result = self.step(batch, self.forward_only)
            # loss and step_time are running averages over the current
            # checkpoint window.
            loss += result['loss'] / self.steps_per_checkpoint
            curr_step_time = (time.time() - start_time)
            step_time += curr_step_time / self.steps_per_checkpoint

            writer.add_summary(result['summaries'], current_step)

            # Cap the exponent to avoid math.exp overflow.
            step_perplexity = math.exp(
                result['loss']) if result['loss'] < 300 else float('inf')

            with open('log.txt', 'a') as log:
                log.write('\nStep %i: %.3fs, loss: %f, perplexity: %f.' %
                          (current_step, curr_step_time, result['loss'],
                           step_perplexity))

                # Once in a while, save a checkpoint and print statistics.
                if current_step % self.steps_per_checkpoint == 0:
                    perplexity = math.exp(loss) if loss < 300 else float('inf')
                    log.write(
                        "\nGlobal step %d. Time: %.3f, loss: %f, perplexity: %.2f."
                        % (self.sess.run(
                            self.global_step), step_time, loss, perplexity))
                    log.write("\nSaving the model at step %d." % current_step)
                    self.saver_all.save(self.sess,
                                        self.checkpoint_path,
                                        global_step=self.global_step)
                    step_time, loss = 0.0, 0.0

        # Statistics for the final (possibly partial) checkpoint window.
        perplexity = math.exp(loss) if loss < 300 else float('inf')
        with open('log.txt', 'a') as log:
            log.write(
                "\nGlobal step %d. Time: %.3f, loss: %f, perplexity: %.2f." %
                (self.sess.run(self.global_step), step_time, loss, perplexity))
            log.write(
                "\nFinishing the training and saving the model at step %d." %
                current_step)
        self.saver_all.save(self.sess,
                            self.checkpoint_path,
                            global_step=self.global_step)
예제 #18
0
import numpy as np
from sklearn.model_selection import train_test_split
from skimage.measure import compare_psnr, compare_ssim
from skimage.transform import resize

db_path = r'E:\images.h5'
# Alternative database location: r'G:\mix5k.h5'

# Train/test split ratio and patch-extraction / batching settings.
train_split = 0.75
ishape = (16, 16)
pshape = (32, 32)
strides = (16, 16)
bsize = 64
num_epochs = 30

dg = DataGen(db_path, pshape=pshape, strides=strides)
with dg.load_db() as db:
    all_data = dg.get_datasets()
    train, test = train_test_split(all_data,
                                   train_size=train_split,
                                   shuffle=True,
                                   random_state=7)
    train_patches = dg.get_patch_list(train)
    # Test patches use a larger window/stride than the training ones.
    test_patches = dg.get_patch_list(test, pshape=(64, 64), strides=(32, 32))

    print('Loading Images ....')
    dg.load_images()

    train_gen, test_gen = (dg.patch_gen(patches, bsize)
                           for patches in (train_patches, test_patches))

    # Number of full batches per pass.
    len_train, len_test = (len(train_patches) // bsize,
                           len(test_patches) // bsize)
예제 #19
0
    def train(self, data_path, num_epoch, learning_rate):
        """Train for ``num_epoch`` epochs, skipping batches whose step fails.

        Per-step loss/perplexity are logged and written as TensorBoard
        summaries. Every ``self.steps_per_checkpoint`` successful-or-skipped
        steps, window-averaged statistics are logged and a checkpoint saved.
        Batches whose ``self.step`` raises are counted, logged and skipped.

        Args:
            data_path: Dataset location handed to ``DataGen``.
            num_epoch: Number of passes over the data.
            learning_rate: Forwarded to ``self.step`` on every batch.
        """
        logging.info('num_epoch: %d', num_epoch)
        s_gen = DataGen(data_path,
                        self.buckets,
                        epochs=num_epoch,
                        max_width=self.max_original_width)
        step_time = 0.0
        loss = 0.0
        current_step = 0
        skipped_counter = 0
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)

        logging.info('Starting the training process.')
        for batch in s_gen.gen(self.batch_size):

            current_step += 1

            start_time = time.time()
            try:
                result = self.step(batch, self.forward_only, learning_rate)
            except Exception as e:  # deliberate best-effort: skip bad batches
                skipped_counter += 1
                # Lazy %-args fill BOTH placeholders (the old str.format call
                # only bound to the second literal, logging the wrong value).
                logging.info('Step %d failed, batch skipped. Total skipped: %d',
                             current_step, skipped_counter)
                logging.error('Step %d failed. Exception details: %s',
                              current_step, e)
                continue

            # loss and step_time are running averages over the current
            # checkpoint window.
            loss += result['loss'] / self.steps_per_checkpoint
            curr_step_time = (time.time() - start_time)
            step_time += curr_step_time / self.steps_per_checkpoint

            writer.add_summary(result['summaries'], current_step)

            # Cap the exponent to avoid math.exp overflow.
            step_perplexity = math.exp(
                result['loss']) if result['loss'] < 300 else float('inf')

            logging.info('Step %i: %.3fs, loss: %f, perplexity: %f.',
                         current_step, curr_step_time, result['loss'],
                         step_perplexity)

            # Once in a while, save a checkpoint and print statistics.
            if current_step % self.steps_per_checkpoint == 0:
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                logging.info(
                    "Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                    self.sess.run(self.global_step), step_time, loss,
                    perplexity)
                logging.info("Saving the model at step %d.", current_step)
                self.saver_all.save(self.sess,
                                    self.checkpoint_path,
                                    global_step=self.global_step)
                step_time, loss = 0.0, 0.0

        # Statistics for the final (possibly partial) checkpoint window.
        perplexity = math.exp(loss) if loss < 300 else float('inf')
        logging.info("Global step %d. Time: %.3f, loss: %f, perplexity: %.2f.",
                     self.sess.run(self.global_step), step_time, loss,
                     perplexity)

        if skipped_counter:
            logging.info("Skipped %d batches due to errors.", skipped_counter)

        logging.info("Finishing the training and saving the model at step %d.",
                     current_step)
        self.saver_all.save(self.sess,
                            self.checkpoint_path,
                            global_step=self.global_step)
예제 #20
0
파일: train.py 프로젝트: jfyang-cn/tf-learn
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    tf_config.gpu_options.per_process_gpu_memory_fraction = 0.3
elif tf.__version__ == '1.11.0' or tf.__version__ == '1.13.2':
    from tensorflow import ConfigProto
    from tensorflow import InteractiveSession

    tf_config = ConfigProto(allow_soft_placement=True)
    tf_config.gpu_options.allow_growth = True

    tf_config.gpu_options.per_process_gpu_memory_fraction = 0.9

save_dir = './results'
mask_list_path = '/home/philyang/git/dataset/helmet/train_mask.txt'
dataGen = DataGen(mask_list_path)

# Checkpoints live in a per-dataset directory, created on first use.
dataset_name = dataGen.name()
dirname = 'ckpt-' + dataset_name
if not os.path.exists(dirname):
    os.makedirs(dirname)

# Timestamped filename template; {epoch}/{loss} are filled in by Keras.
timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
ckpt_name = 'weights-%s-{epoch:02d}-{loss:.2f}.hdf5' % (timestr)
filepath = os.path.join(dirname, ckpt_name)
checkpoint = ModelCheckpoint(
    filepath=filepath,
    monitor='loss',  # acc outperforms loss
    verbose=1,
    save_best_only=True,