Example #1
def get_model():
    if model_index == 0:
        return mobilenet_v1.MobileNetV1()
    elif model_index == 1:
        return mobilenet_v2.MobileNetV2()
    elif model_index == 2:
        return mobilenet_v3_large.MobileNetV3Large()
    elif model_index == 3:
        return mobilenet_v3_small.MobileNetV3Small()
    elif model_index == 4:
        return efficientnet.efficient_net_b0()
    elif model_index == 5:
        return efficientnet.efficient_net_b1()
    elif model_index == 6:
        return efficientnet.efficient_net_b2()
    elif model_index == 7:
        return efficientnet.efficient_net_b3()
    elif model_index == 8:
        return efficientnet.efficient_net_b4()
    elif model_index == 9:
        return efficientnet.efficient_net_b5()
    elif model_index == 10:
        return efficientnet.efficient_net_b6()
    elif model_index == 11:
        return efficientnet.efficient_net_b7()
    elif model_index == 12:
        return resnext.ResNeXt50()
    elif model_index == 13:
        return resnext.ResNeXt101()
    elif model_index == 14:
        return inception_v4.InceptionV4()
    elif model_index == 15:
        return inception_resnet_v1.InceptionResNetV1()
    elif model_index == 16:
        return inception_resnet_v2.InceptionResNetV2()
    elif model_index == 17:
        return se_resnet.se_resnet_50()
    elif model_index == 18:
        return se_resnet.se_resnet_101()
    elif model_index == 19:
        return se_resnet.se_resnet_152()
    elif model_index == 20:
        return squeezenet.SqueezeNet()
    elif model_index == 21:
        return densenet.densenet_121()
    elif model_index == 22:
        return densenet.densenet_169()
    elif model_index == 23:
        return densenet.densenet_201()
    elif model_index == 24:
        return densenet.densenet_264()
    elif model_index == 25:
        return shufflenet_v2.shufflenet_0_5x()
    elif model_index == 26:
        return shufflenet_v2.shufflenet_1_0x()
    elif model_index == 27:
        return shufflenet_v2.shufflenet_1_5x()
    elif model_index == 28:
        return shufflenet_v2.shufflenet_2_0x()
    else:
        raise ValueError("The model_index does not exist.")
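
The long elif chain above can also be collapsed into a lookup table. A minimal sketch, assuming the same model modules are in scope (MODEL_REGISTRY is a name introduced here for illustration):

# dispatch-table alternative to the elif chain; the factories are stored
# uncalled so that only the selected model is instantiated
MODEL_REGISTRY = {
    0: mobilenet_v1.MobileNetV1,
    1: mobilenet_v2.MobileNetV2,
    2: mobilenet_v3_large.MobileNetV3Large,
    3: mobilenet_v3_small.MobileNetV3Small,
    # ... indices 4-27 follow the same pattern ...
    28: shufflenet_v2.shufflenet_2_0x,
}

def get_model():
    try:
        return MODEL_REGISTRY[model_index]()
    except KeyError:
        raise ValueError("The model_index does not exist.")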
Example #2
def train_resnet(device, model):
    """
  Loads training and validations tf records and trains resnet model and validates every number of fixed steps.
  Input: device - gpu device number 
         model - name of deep learning model, options inclde: se_resnet_101 and densenet_121
  Output: None
  """
    # use nvidia-smi to see available devices; '0' means the first GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device)
    # loads the XRAYconfig configuration defined in resnet_config
    config = XRAYconfig()
    # load training data
    train_meta = np.load(tfrecord2metafilename(config.train_fn))
    print('Using train tfrecords: {0} | {1} images'.format(
        config.train_fn, len(train_meta['labels'])))
    train_filename_queue = tf.train.string_input_producer(
        [config.train_fn], num_epochs=config.num_train_epochs)
    # load validation data
    val_meta = np.load(tfrecord2metafilename(config.val_fn))
    print('Using validation tfrecords: {0} | {1} images'.format(
        config.val_fn, len(val_meta['labels'])))
    val_filename_queue = tf.train.string_input_producer(
        [config.val_fn], num_epochs=config.num_train_epochs)

    model_train_name = model
    dt_stamp = time.strftime(model_train_name + "_%Y_%m_%d_%H_%M_%S")
    out_dir = config.get_results_path(model_train_name, dt_stamp)
    summary_dir = config.get_summaries_path(model_train_name, dt_stamp)
    print('-' * 60)
    print('Training model: {0}'.format(dt_stamp))
    print('-' * 60)

    # decoding training tfrecords
    train_img, train_t_l, train_b_t, _ = read_and_decode(
        filename_queue=train_filename_queue,
        img_dims=config.input_image_size,
        model_dims=config.model_image_size,
        size_of_batch=config.train_batch_size,
        augmentations_dic=config.train_augmentations_dic,
        num_of_threads=4,
        shuffle=True)

    # decoding validation tfrecords
    val_img, val_t_l, val_b_t, _ = read_and_decode(
        filename_queue=val_filename_queue,
        img_dims=config.input_image_size,
        model_dims=config.model_image_size,
        size_of_batch=config.val_batch_size,
        augmentations_dic=config.val_augmentations_dic,
        num_of_threads=4,
        shuffle=False)

    # summaries to use with tensorboard check https://www.tensorflow.org/get_started/summaries_and_tensorboard
    tf.summary.image('train images', train_img, max_outputs=10)
    tf.summary.image('validation images', val_img, max_outputs=10)

    # creating step op that counts the number of training steps
    step = tf.train.get_or_create_global_step()
    step_op = tf.assign(step, step + 1)

    if model == 'se_resnet_101':
        print("Loading Resnet 101...")
        with tf.variable_scope('resnet_v2_101') as resnet_scope:
            with tf.name_scope('train') as train_scope:
                train_img = imagenet_preprocessing(train_img)
                with slim.arg_scope(
                        se_resnet.resnet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    train_target_logits, _ = se_resnet.se_resnet_101(
                        inputs=train_img,
                        num_classes=config.output_shape,
                        scope=resnet_scope,
                        is_training=True)

            resnet_scope.reuse_variables()
            with tf.name_scope('val') as val_scope:
                val_img = imagenet_preprocessing(val_img)
                with slim.arg_scope(
                        se_resnet.resnet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    val_target_logits, _ = se_resnet.se_resnet_101(
                        inputs=val_img,
                        num_classes=config.output_shape,
                        scope=resnet_scope,
                        is_training=False)
    elif model == 'densenet_121':
        print("Loading Densenet 121...")
        with tf.variable_scope('densenet121') as densenet_scope:
            with tf.name_scope('train') as train_scope:
                train_img = imagenet_preprocessing(train_img)
                with slim.arg_scope(
                        densenet.densenet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    train_target_logits, _ = densenet.densenet121(
                        inputs=train_img,
                        num_classes=config.output_shape,
                        is_training=True,
                        scope=densenet_scope)
            print_model_variables()
            densenet_scope.reuse_variables()
            with tf.name_scope('val') as val_scope:
                val_img = imagenet_preprocessing(val_img)
                with slim.arg_scope(
                        densenet.densenet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    val_target_logits, _ = densenet.densenet121(
                        inputs=val_img,
                        num_classes=config.output_shape,
                        is_training=False,
                        scope=densenet_scope)
    else:
        raise Exception(
            'Model not implemented! Options are se_resnet_101 and densenet_121')

    loss = weighted_softmax_cross_entropy_with_logits(
        train_t_l, train_target_logits, config.output_shape,
        'target_class_weights.npy')

    tf.summary.scalar("loss", loss)

    lr = tf.train.exponential_decay(
        learning_rate=config.initial_learning_rate,
        global_step=step_op,
        decay_steps=config.decay_steps,
        decay_rate=config.learning_rate_decay_factor,
        staircase=True
    )  # if staircase is True decay the learning rate at discrete intervals
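    # with staircase=True the schedule is
    #   lr = initial_learning_rate * decay_rate ** floor(global_step / decay_steps)
    # i.e. the rate drops by learning_rate_decay_factor once every
    # decay_steps steps rather than decaying continuously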

    if config.optimizer == "adam":
        update_ops = tf.get_collection(
            tf.GraphKeys.UPDATE_OPS)  # used to update batch norm params.
        # see https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(lr).minimize(loss)
    elif config.optimizer == "sgd":
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss)
    elif config.optimizer == "nestrov":
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.MomentumOptimizer(
                lr, config.momentum, use_nesterov=True).minimize(loss)
    else:
        raise Exception(
            "Unknown optimizer! Options are adam, sgd or nesterov")

    train_prob = tf.nn.softmax(train_target_logits)
    train_pred = tf.argmax(train_prob, 1)

    val_prob = tf.nn.softmax(val_target_logits)
    val_pred = tf.argmax(val_prob, 1)

    train_accuracy = tf.contrib.metrics.accuracy(train_pred, train_t_l)
    val_accuracy = tf.contrib.metrics.accuracy(val_pred, val_t_l)

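    # note: tf.metrics.auc is given the hard argmax predictions here; a
    # threshold-based ROC AUC would normally use the positive-class
    # probability (e.g. train_prob[:, 1]) instead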
    train_auc, train_auc_op = tf.metrics.auc(train_t_l, train_pred)
    val_auc, val_auc_op = tf.metrics.auc(val_t_l, val_pred)

    tf.summary.scalar("training accuracy", train_accuracy)
    tf.summary.scalar("validation accuracy", val_accuracy)
    tf.summary.scalar("training auc", train_auc)
    tf.summary.scalar("validation auc", val_auc)

    if config.restore:
        # adjusting which variables to keep in the model;
        # excluding variables allows for transfer learning (normally the fully connected layers are excluded)
        exclusions = [
            scope.strip() for scope in config.checkpoint_exclude_scopes
        ]
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
        print("Restroing variables:")
        for var in variables_to_restore:
            print(var)
        restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(slim.get_model_variables(), max_to_keep=100)
    summary_op = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))

        if config.restore:
            restorer.restore(sess, config.model_path)

        summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        np.save(os.path.join(out_dir, 'training_config_file'), config)

        val_acc_max = 0

        try:

            while not coord.should_stop():

                start_time = time.time()

                step_count, loss_value, train_acc_value, lr_value, _ = sess.run(
                    [step_op, loss, train_accuracy, lr, train_op])
                sess.run(train_auc_op)
                train_auc_value = sess.run(train_auc)

                duration = time.time() - start_time
                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'
                step_count = step_count - 1

                if step_count % config.validate_every_num_steps == 0:
                    it_val_acc = np.asarray([])
                    for num_vals in range(config.num_batches_to_validate_over):
                        # Validation accuracy as the average of n batches
                        it_val_acc = np.append(it_val_acc,
                                               sess.run(val_accuracy))
                        sess.run(val_auc_op)

                    val_acc_value = it_val_acc.mean()
                    val_auc_value = sess.run(val_auc)
                    # Summaries
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step_count)

                    # Training status and validation accuracy
                    msg = '{0}: step {1}, loss = {2:.4f} ({3:.2f} examples/sec; '\
                        + '{4:.2f} sec/batch) | Training accuracy = {5:.4f} | Training AUC = {6:.4f} '\
                        + '| Validation accuracy = {7:.4f} | Validation AUC = {8:.4f}| logdir = {9}'
                    print(
                        msg.format(datetime.datetime.now(), step_count,
                                   loss_value,
                                   (config.train_batch_size / duration),
                                   float(duration), train_acc_value,
                                   train_auc_value, val_acc_value,
                                   val_auc_value, summary_dir))
                    print("Learning rate: {}".format(lr_value))
                    # Save the model checkpoint if it's the best yet
                    if val_acc_value >= val_acc_max:
                        file_name = '{0}_{1}'.format(dt_stamp, step_count)
                        saver.save(
                            sess,
                            config.get_checkpoint_filename(
                                model_train_name, file_name))
                        # Store the new max validation accuracy
                        val_acc_max = val_acc_value

                else:
                    # Training status
                    msg = '{0}: step {1}, loss = {2:.4f} ({3:.2f} examples/sec; '\
                        + '{4:.2f} sec/batch) | Training accuracy = {5:.4f} | Training AUC = {6:.4f}'
                    print(
                        msg.format(datetime.datetime.now(), step_count,
                                   loss_value,
                                   (config.train_batch_size / duration),
                                   float(duration), train_acc_value,
                                   train_auc_value))
                # End iteration

        except tf.errors.OutOfRangeError:
            print('Done training for {0} epochs, {1} steps.'.format(
                config.num_train_epochs, step_count))
        finally:
            coord.request_stop()
        coord.join(threads)
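
A minimal invocation sketch (hypothetical: the device number and model name below are illustrative choices, not values from the source):

if __name__ == '__main__':
    # train on the first GPU using the DenseNet-121 branch of train_resnet
    train_resnet(device=0, model='densenet_121')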
Example #3
def get_model():
    if model_index == 0:
        return mobilenet_v1.MobileNetV1()
    elif model_index == 1:
        return mobilenet_v2.MobileNetV2()
    elif model_index == 2:
        return mobilenet_v3_large.MobileNetV3Large()
    elif model_index == 3:
        return mobilenet_v3_small.MobileNetV3Small()
    elif model_index == 4:
        return efficientnet.efficient_net_b0()
    elif model_index == 5:
        return efficientnet.efficient_net_b1()
    elif model_index == 6:
        return efficientnet.efficient_net_b2()
    elif model_index == 7:
        return efficientnet.efficient_net_b3()
    elif model_index == 8:
        return efficientnet.efficient_net_b4()
    elif model_index == 9:
        return efficientnet.efficient_net_b5()
    elif model_index == 10:
        return efficientnet.efficient_net_b6()
    elif model_index == 11:
        return efficientnet.efficient_net_b7()
    elif model_index == 12:
        return resnext.ResNeXt50()
    elif model_index == 13:
        return resnext.ResNeXt101()
    elif model_index == 14:
        return inception_v4.InceptionV4()
    elif model_index == 15:
        return inception_resnet_v1.InceptionResNetV1()
    elif model_index == 16:
        return inception_resnet_v2.InceptionResNetV2()
    elif model_index == 17:
        return se_resnet.se_resnet_50()
    elif model_index == 18:
        return se_resnet.se_resnet_101()
    elif model_index == 19:
        return se_resnet.se_resnet_152()
    elif model_index == 20:
        return squeezenet.SqueezeNet()
    elif model_index == 21:
        return densenet.densenet_121()
    elif model_index == 22:
        return densenet.densenet_169()
    elif model_index == 23:
        return densenet.densenet_201()
    elif model_index == 24:
        return densenet.densenet_264()
    elif model_index == 25:
        return shufflenet_v2.shufflenet_0_5x()
    elif model_index == 26:
        return shufflenet_v2.shufflenet_1_0x()
    elif model_index == 27:
        return shufflenet_v2.shufflenet_1_5x()
    elif model_index == 28:
        return shufflenet_v2.shufflenet_2_0x()
    elif model_index == 29:
        return resnet.resnet_18()
    elif model_index == 30:
        return resnet.resnet_34()
    elif model_index == 31:
        return resnet.resnet_50()
    elif model_index == 32:
        return resnet.resnet_101()
    elif model_index == 33:
        return resnet.resnet_152()
    elif model_index == 34:
        return vgg16.VGG16()
    elif model_index == 35:
        return vgg16_mini.VGG16()
    elif model_index == 36:
        return VGG16_self.VGG16()
    elif model_index == 10086:
        return diy_resnet.resnet_50()
    else:
        raise ValueError("The model_index does not exist.")
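
Since get_model reads model_index from the enclosing scope, a caller sets it before the call. A usage sketch (assuming, as the class names suggest but this snippet does not show, that the factories return tf.keras.Model instances):

# hypothetical usage; model_index is assumed to be a module-level setting
model_index = 31  # selects resnet.resnet_50()
model = get_model()
model.build(input_shape=(None, 224, 224, 3))  # NHWC input shape is an assumption
model.summary()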