Example #1
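A `train()` method for a tf.keras model: it compiles the model with the configured optimizer, loss, and metrics, assembles TensorBoard, CSV-logging, console-printing, profiling, early-stopping, checkpointing, and learning-rate callbacks from `self.flags`, then runs `model.fit()` on the prepared train/validation iterators and returns the training history.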
 def train(self, model):
     self.create_train_data_iterator()
     optimizer = self.build_optimizer()
     model.compile(optimizer=optimizer,
                   loss=self.loss_object,
                   loss_weights=self.loss_weight,
                   metrics=self.metrics_object)
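     # Always-on callbacks: TensorBoard summaries, a CSV history log, and a custom console printer.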
     callbacks = [
         tf.keras.callbacks.TensorBoard(log_dir=self.flags.summaries_dir),
         CSVLogger(self.flags.summaries_dir + '/log.csv',
                   append=True,
                   separator=','),
         LossAndErrorPrintingCallback()
     ]
     if self.flags.profile_batch:
         tb_callback = tf.keras.callbacks.TensorBoard(
             log_dir=self.flags.summaries_dir,
             profile_batch=self.flags.profile_batch)
         callbacks.append(tb_callback)
     if self.flags.patient_valid_passes:
         early_stopping = tf.keras.callbacks.EarlyStopping(
             monitor='val_loss',
             patience=self.flags.patient_valid_passes,
             mode='min',
             restore_best_weights=True)
         callbacks.append(early_stopping)
     if self.flags.checkpoint_path:
         # Create a callback that saves the model's weights
         cp_callback = tf.keras.callbacks.ModelCheckpoint(
             filepath=self.flags.checkpoint_path,
             save_weights_only=True,
             monitor='val_loss',
             mode='auto',
             save_best_only=True)
         callbacks.append(cp_callback)
     if self.flags.lr_schedule:
         callbacks.append(LearningRateScheduler(self.lr_schedule))
     history = model.fit(self.train_iterator,
                         validation_data=self.valid_iterator,
                         epochs=self.flags.epochs,
                         callbacks=callbacks)
     return history
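`LossAndErrorPrintingCallback` is project-specific and not shown above; a minimal sketch of such a callback, assuming it simply prints the tracked losses at the end of each epoch:

import tensorflow as tf

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
    """Print the training loss (and validation loss, if present) after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print('Epoch {}: loss={:.4f}, val_loss={:.4f}'.format(
            epoch, logs.get('loss', float('nan')),
            logs.get('val_loss', float('nan'))))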
Example #2
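Fine-tuning a Keras classifier: a linear Dense layer plus softmax activation is attached to existing features, pretrained weights are loaded from disk, and training runs through `fit_generator()` with `LearningRateScheduler` and `ModelCheckpoint` callbacks.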
flatten = keras.layers.Flatten(name='Flatten')(feature)
classes = keras.layers.Dense(num_classes, activation='linear', kernel_initializer='he_normal',
                    kernel_regularizer=l2(weight_decay), name='Classes_Congr')(flatten)
outputs = keras.layers.Activation('softmax', name='Output')(classes)

model = keras.models.Model(inputs=inputs, outputs=outputs)
pretrained_model = '%s/trained_models/rcm/congr_rcm_rgb_weights.h5'%model_prefix
print('Loading pretrained model from %s' % pretrained_model)
model.load_weights(pretrained_model, by_name=False)
for i in range(len(model.trainable_weights)):
    print(model.trainable_weights[i])

optimizer = keras.optimizers.SGD(lr=0.001, decay=0, momentum=0.9, nesterov=False)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

lr_reducer = LearningRateScheduler(lr_polynomial_decay, train_steps)
model_checkpoint = ModelCheckpoint(weights_file, monitor="val_acc",
                                   save_best_only=False, save_weights_only=True,
                                   mode='auto')
callbacks = [lr_reducer, model_checkpoint]

model.fit_generator(conTrainImageGenerator(training_datalist, 
                                           batch_size, depth, num_classes, RGB),
          steps_per_epoch=train_steps,
          epochs=nb_epoch,
          verbose=1,
          callbacks=callbacks,
          validation_data=conTestImageGenerator(testing_datalist, 
                                                batch_size, depth, num_classes, RGB),
          validation_steps=test_steps,
          initial_epoch=init_epoch,
          )
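`lr_polynomial_decay` (and the two-argument `LearningRateScheduler` call above) is project-specific and not shown. A minimal sketch of a polynomial decay schedule usable with the stock `keras.callbacks.LearningRateScheduler`; the base rate matches the SGD setting above, while the epoch count and power are illustrative assumptions:

BASE_LR = 0.001        # matches the SGD lr above
TOTAL_EPOCHS = 100     # illustrative stand-in for the script's nb_epoch
POWER = 1.0            # 1.0 gives a linear decay

def lr_polynomial_decay(epoch):
    """Polynomial decay of the learning rate from BASE_LR towards 0."""
    return BASE_LR * (1.0 - float(epoch) / TOTAL_EPOCHS) ** POWER

lr_reducer = LearningRateScheduler(lr_polynomial_decay)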
Example #3
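A complete training script: it compiles a model with SGD, mean-squared-error loss, and a custom mean-corner-error metric, derives the epoch count from a target iteration budget, splits a directory of saved arrays 80/20 into train and test sets, and trains via `fit_generator()` with learning-rate-scheduling and checkpointing callbacks.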
def main():

    model = create_model()

    # Configuration
    batch_size = 64
    target_iterations = 90000  # at batch_size = 64
    base_lr = 0.005

    sgd = SGD(lr=base_lr, momentum=0.9)

    model.compile(optimizer=sgd,
                  loss='mean_squared_error',
                  metrics=[mean_corner_error])
    model.summary()

    save_path = os.path.dirname(os.path.realpath(__file__))
    checkpoint = ModelCheckpoint(
        os.path.join(save_path, 'model.{epoch:02d}.h5'))

    # LR scaling as described in the paper: https://arxiv.org/pdf/1606.03798.pdf
    lr_scheduler = LearningRateScheduler(base_lr, 0.1, 30000)

    # In the paper, the 90,000 iterations was for batch_size = 64
    # So scale appropriately
    target_iterations = int(target_iterations * 64 / batch_size)

    # load data here from ME3
    # separate into test and sample
    # play with the split, but you can start with:
    # Training: 80%
    # Testing: 20%
    data_dir = 'pokemon_dataset/warped'
    files = os.listdir(data_dir)
    size = len(files)

    TRAIN_SAMPLES_COUNT = size * 0.8
    TEST_SAMPLES_COUNT = size * 0.2

    # As stated in Keras docs
    steps_per_epoch = int(TRAIN_SAMPLES_COUNT / batch_size)
    epochs = int(math.ceil(target_iterations / steps_per_epoch))

    train_data = []
    test_data = []

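    # Load the 'img' array from each saved file: the first 80% of files go to train, the last 20% to test.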
    for i in range(math.ceil(TRAIN_SAMPLES_COUNT)):
        temp = np.load(os.getcwd() + '/pokemon_dataset/warped/' + files[i])
        train_data.append(temp['img'])

    for i in range(size - math.ceil(TEST_SAMPLES_COUNT) + 1, size):
        temp = np.load(os.getcwd() + '/pokemon_dataset/warped/' + files[i])
        test_data.append(temp['img'])
    test_steps = int(TEST_SAMPLES_COUNT / batch_size)

    # Train
    model.fit_generator(train_data,
                        steps_per_epoch,
                        epochs,
                        callbacks=[lr_scheduler, checkpoint],
                        validation_data=test_data,
                        validation_steps=test_steps)
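In stock Keras, `LearningRateScheduler` takes a schedule function of the epoch index rather than `(base_lr, factor, step)` arguments, so the call above presumably refers to a project-specific scheduler. A minimal sketch of an equivalent step decay using the stock callback, reusing `base_lr` and `steps_per_epoch` from above; the import assumes tf.keras and the factory name is illustrative:

from tensorflow.keras.callbacks import LearningRateScheduler

def make_step_decay(base_lr, drop_factor, drop_every_iters, steps_per_epoch):
    """Multiply the learning rate by drop_factor every drop_every_iters iterations."""
    def schedule(epoch):
        iterations = epoch * steps_per_epoch
        return base_lr * (drop_factor ** (iterations // drop_every_iters))
    return schedule

lr_scheduler = LearningRateScheduler(
    make_step_decay(base_lr, 0.1, 30000, steps_per_epoch))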
Example #4
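Training and evaluation with a project-specific PyTorch-style `Trainer` whose callbacks (plotting, TensorBoard, checkpointing, a learning-rate scheduler wrapping `ReduceLROnPlateau`, and early stopping) mirror the Keras callback API; when `args.is_train` is not set, the best checkpoint is loaded and `trainer.test()` runs instead. The snippet begins mid-statement in the original source.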
                                momentum=args.momentum)

    logger.info('Number of model parameters: {:,}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    trainer = Trainer(model, optimizer, watch=['acc'], val_watch=['acc'])

    if args.is_train:
        logger.info("Train on {} samples, validate on {} samples".format(
            len(train_loader.dataset), len(val_loader.dataset)))
        start_epoch = 0
        if args.resume:
            start_epoch = load_checkpoint(args.ckpt_dir, model, optimizer)

        trainer.train(train_loader,
                      val_loader,
                      start_epoch=start_epoch,
                      epochs=args.epochs,
                      callbacks=[
                          PlotCbk(model, args.plot_num_imgs, args.plot_freq,
                                  args.use_gpu),
                          TensorBoard(model, args.log_dir),
                          ModelCheckpoint(model, optimizer, args.ckpt_dir),
                          LearningRateScheduler(
                              ReduceLROnPlateau(optimizer, 'min'), 'val_loss'),
                          EarlyStopping(model, patience=args.patience)
                      ])
    else:
        logger.info("Test on {} samples".format((len(test_loader))))
        load_checkpoint(args.ckpt_dir, model, best=True)
        trainer.test(test_loader, best=args.best)
Example #5
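A variant of the previous example: the same project-specific `Trainer.train()` call with a `ReduceLROnPlateau`-based learning-rate scheduler and early stopping, plus an `is_plot` branch that loads the best checkpoint and plots a random batch from an `ImageFolder`. This snippet also begins mid-statement in the original source.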
                                          best=args.best)

        trainer.train(
            train_loader,
            val_loader,
            start_epoch=start_epoch,
            epochs=args.epochs,
            callbacks=[
                PlotCbk(args.plot_dir, model, args.plot_num_imgs,
                        args.plot_freq, args.use_gpu),
                # TensorBoard(model, args.log_dir),
                ModelCheckpoint(model, optimizer, args.ckpt_dir),
                # LearningRateScheduler(ReduceLROnPlateau(optimizer, 'min'), 'val_loss'),
                LearningRateScheduler(
                    ReduceLROnPlateau(optimizer,
                                      factor=0.1,
                                      patience=5,
                                      mode='min'), 'val_loss'),
                EarlyStopping(model, patience=args.patience)
            ])
    elif args.is_plot:
        dataset = ImageFolder(os.path.join(args.data_dir, args.mode),
                              testtransformSequence)
        loader = get_test_loader(dataset, args.num_plots, **kwargs)
        logger.info("Plotting a random batch from the folder {}".format(
            args.mode))
        start_epoch = load_checkpoint(args.ckpt_dir, model, False, best=True)
        trainer.plot(
            loader,
            PlotCbk(args.plot_dir, model, args.num_plots, 1, args.use_gpu),
            args.plot_name)