import os
import time
import datetime

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

from resnet_config import XRAYconfig
# The following helpers are defined elsewhere in this repo (module paths not
# shown here): tfrecord2metafilename, read_and_decode, imagenet_preprocessing,
# weighted_softmax_cross_entropy_with_logits, print_model_variables, and the
# model modules se_resnet and densenet.


def train_resnet(device, model):
    """
    Loads the training and validation tfrecords, trains the selected model,
    and runs validation every fixed number of steps.

    Input:
      device - GPU device number
      model  - name of the deep learning model; options include
               'se_resnet_101' and 'densenet_121'
    Output: None
    """
    # Use nvidia-smi to see available options; '0' means the first GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device)
    config = XRAYconfig()  # loads the XRAYconfig configuration defined in resnet_config

    # load training data
    train_meta = np.load(tfrecord2metafilename(config.train_fn))
    print('Using train tfrecords: {0} | {1} images'.format(
        config.train_fn, len(train_meta['labels'])))
    train_filename_queue = tf.train.string_input_producer(
        [config.train_fn], num_epochs=config.num_train_epochs)

    # load validation data
    val_meta = np.load(tfrecord2metafilename(config.val_fn))
    print('Using validation tfrecords: {0} | {1} images'.format(
        config.val_fn, len(val_meta['labels'])))
    val_filename_queue = tf.train.string_input_producer(
        [config.val_fn], num_epochs=config.num_train_epochs)

    model_train_name = model
    dt_stamp = time.strftime(model_train_name + "_%Y_%m_%d_%H_%M_%S")
    out_dir = config.get_results_path(model_train_name, dt_stamp)
    summary_dir = config.get_summaries_path(model_train_name, dt_stamp)
    print('-' * 60)
    print('Training model: {0}'.format(dt_stamp))
    print('-' * 60)

    # decoding training tfrecords
    train_img, train_t_l, train_b_t, _ = read_and_decode(
        filename_queue=train_filename_queue,
        img_dims=config.input_image_size,
        model_dims=config.model_image_size,
        size_of_batch=config.train_batch_size,
        augmentations_dic=config.train_augmentations_dic,
        num_of_threads=4,
        shuffle=True)

    # decoding validation tfrecords
    val_img, val_t_l, val_b_t, _ = read_and_decode(
        filename_queue=val_filename_queue,
        img_dims=config.input_image_size,
        model_dims=config.model_image_size,
        size_of_batch=config.val_batch_size,
        augmentations_dic=config.val_augmentations_dic,
        num_of_threads=4,
        shuffle=False)

    # summaries to use with TensorBoard; see
    # https://www.tensorflow.org/get_started/summaries_and_tensorboard
    tf.summary.image('train images', train_img, max_outputs=10)
    tf.summary.image('validation images', val_img, max_outputs=10)

    # creating a step op that counts the number of training steps
    step = tf.train.get_or_create_global_step()
    step_op = tf.assign(step, step + 1)

    if model == 'se_resnet_101':
        print("Loading SE-ResNet 101...")
        with tf.variable_scope('resnet_v2_101') as resnet_scope:
            with tf.name_scope('train') as train_scope:
                train_img = imagenet_preprocessing(train_img)
                with slim.arg_scope(
                        se_resnet.resnet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    train_target_logits, _ = se_resnet.se_resnet_101(
                        inputs=train_img,
                        num_classes=config.output_shape,
                        scope=resnet_scope,
                        is_training=True)
            # Reuse the training variables for the validation tower.
            resnet_scope.reuse_variables()
            with tf.name_scope('val') as val_scope:
                val_img = imagenet_preprocessing(val_img)
                with slim.arg_scope(
                        se_resnet.resnet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    val_target_logits, _ = se_resnet.se_resnet_101(
                        inputs=val_img,
                        num_classes=config.output_shape,
                        scope=resnet_scope,
                        is_training=False)
    elif model == 'densenet_121':
        print("Loading DenseNet 121...")
        with tf.variable_scope('densenet121') as densenet_scope:
            with tf.name_scope('train') as train_scope:
                train_img = imagenet_preprocessing(train_img)
                with slim.arg_scope(
                        densenet.densenet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    train_target_logits, _ = densenet.densenet121(
                        inputs=train_img,
                        num_classes=config.output_shape,
                        is_training=True,
                        scope=densenet_scope)
            print_model_variables()
            densenet_scope.reuse_variables()
            with tf.name_scope('val') as val_scope:
                val_img = imagenet_preprocessing(val_img)
                with slim.arg_scope(
                        densenet.densenet_arg_scope(
                            weight_decay=config.l2_reg,
                            batch_norm_decay=config.batch_norm_decay,
                            batch_norm_epsilon=config.batch_norm_epsilon)):
                    val_target_logits, _ = densenet.densenet121(
                        inputs=val_img,
                        num_classes=config.output_shape,
                        is_training=False,
                        scope=densenet_scope)
    else:
        raise Exception(
            'Model not implemented! Options are se_resnet_101 and densenet_121')

    loss = weighted_softmax_cross_entropy_with_logits(
        train_t_l, train_target_logits, config.output_shape,
        'target_class_weights.npy')
    tf.summary.scalar("loss", loss)

    # if staircase is True, decay the learning rate at discrete intervals
    lr = tf.train.exponential_decay(
        learning_rate=config.initial_learning_rate,
        global_step=step_op,
        decay_steps=config.decay_steps,
        decay_rate=config.learning_rate_decay_factor,
        staircase=True)

    if config.optimizer == "adam":
        # UPDATE_OPS is used to update batch norm parameters; see
        # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(lr).minimize(loss)
    elif config.optimizer == "sgd":
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss)
    elif config.optimizer == "nesterov":
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.MomentumOptimizer(
                lr, config.momentum, use_nesterov=True).minimize(loss)
    else:
        raise Exception("Unknown optimizer! Options are adam, sgd or nesterov")

    train_prob = tf.nn.softmax(train_target_logits)
    train_pred = tf.argmax(train_prob, 1)
    val_prob = tf.nn.softmax(val_target_logits)
    val_pred = tf.argmax(val_prob, 1)

    train_accuracy = tf.contrib.metrics.accuracy(train_pred, train_t_l)
    val_accuracy = tf.contrib.metrics.accuracy(val_pred, val_t_l)
    train_auc, train_auc_op = tf.metrics.auc(train_t_l, train_pred)
    val_auc, val_auc_op = tf.metrics.auc(val_t_l, val_pred)

    tf.summary.scalar("training accuracy", train_accuracy)
    tf.summary.scalar("validation accuracy", val_accuracy)
    tf.summary.scalar("training auc", train_auc)
    tf.summary.scalar("validation auc", val_auc)

    if config.restore:
        # Choose which variables to keep from the checkpoint. Variables that
        # are excluded allow for transfer learning (normally the fully
        # connected layers are excluded).
        exclusions = [
            scope.strip() for scope in config.checkpoint_exclude_scopes
        ]
        variables_to_restore = []
        for var in slim.get_model_variables():
            excluded = False
            for exclusion in exclusions:
                if var.op.name.startswith(exclusion):
                    excluded = True
                    break
            if not excluded:
                variables_to_restore.append(var)
        print("Restoring variables:")
        for var in variables_to_restore:
            print(var)
        restorer = tf.train.Saver(variables_to_restore)

    saver = tf.train.Saver(slim.get_model_variables(), max_to_keep=100)
    summary_op = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        if config.restore:
            restorer.restore(sess, config.model_path)
        summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        np.save(os.path.join(out_dir, 'training_config_file'), config)
        val_acc_max = 0
        try:
            while not coord.should_stop():
                start_time = time.time()
                step_count, loss_value, train_acc_value, lr_value, _ = sess.run(
                    [step_op, loss, train_accuracy, lr, train_op])
                sess.run(train_auc_op)
                train_auc_value = sess.run(train_auc)
                duration = time.time() - start_time
                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'
                step_count = step_count - 1
                if step_count % config.validate_every_num_steps == 0:
                    it_val_acc = np.asarray([])
                    # Validation accuracy as the average of n batches.
                    for num_vals in range(config.num_batches_to_validate_over):
                        it_val_acc = np.append(it_val_acc,
                                               sess.run(val_accuracy))
                        sess.run(val_auc_op)
                    val_acc_value = it_val_acc.mean()
                    val_auc_value = sess.run(val_auc)
                    # Summaries
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step_count)
                    # Training status and validation accuracy
                    msg = '{0}: step {1}, loss = {2:.4f} ({3:.2f} examples/sec; '\
                        + '{4:.2f} sec/batch) | Training accuracy = {5:.4f} | Training AUC = {6:.4f} '\
                        + '| Validation accuracy = {7:.4f} | Validation AUC = {8:.4f} | logdir = {9}'
                    print(
                        msg.format(datetime.datetime.now(), step_count,
                                   loss_value,
                                   (config.train_batch_size / duration),
                                   float(duration), train_acc_value,
                                   train_auc_value, val_acc_value,
                                   val_auc_value, summary_dir))
                    print("Learning rate: {}".format(lr_value))
                    # Save the model checkpoint if it's the best yet.
                    if val_acc_value >= val_acc_max:
                        file_name = '{0}_{1}'.format(dt_stamp, step_count)
                        saver.save(
                            sess,
                            config.get_checkpoint_filename(
                                model_train_name, file_name))
                        # Store the new max validation accuracy.
                        val_acc_max = val_acc_value
                else:
                    # Training status
                    msg = '{0}: step {1}, loss = {2:.4f} ({3:.2f} examples/sec; '\
                        + '{4:.2f} sec/batch) | Training accuracy = {5:.4f} | Training AUC = {6:.4f}'
                    print(
                        msg.format(datetime.datetime.now(), step_count,
                                   loss_value,
                                   (config.train_batch_size / duration),
                                   float(duration), train_acc_value,
                                   train_auc_value))
            # End iteration
        except tf.errors.OutOfRangeError:
            print('Done training for {0} epochs, {1} steps.'.format(
                config.num_train_epochs, step_count))
        finally:
            coord.request_stop()
        coord.join(threads)
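

# A minimal, hypothetical entry point for the trainer above; it is not part of
# the original module. It assumes train_resnet is called with a GPU index and
# one of the two supported model names; the flag names are illustrative.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Train a chest X-ray model.')
    parser.add_argument('--device', type=int, default=0,
                        help='GPU index exported via CUDA_VISIBLE_DEVICES')
    parser.add_argument('--model', default='se_resnet_101',
                        choices=['se_resnet_101', 'densenet_121'],
                        help='which architecture to train')
    args = parser.parse_args()
    train_resnet(args.device, args.model)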
def get_model():
    """Instantiate the network selected by ``model_index``.

    ``model_index`` is expected to be defined at module level (e.g. imported
    from the project's configuration file).
    """
    # Map each index to a constructor so only the selected model is built.
    model_constructors = {
        0: mobilenet_v1.MobileNetV1,
        1: mobilenet_v2.MobileNetV2,
        2: mobilenet_v3_large.MobileNetV3Large,
        3: mobilenet_v3_small.MobileNetV3Small,
        4: efficientnet.efficient_net_b0,
        5: efficientnet.efficient_net_b1,
        6: efficientnet.efficient_net_b2,
        7: efficientnet.efficient_net_b3,
        8: efficientnet.efficient_net_b4,
        9: efficientnet.efficient_net_b5,
        10: efficientnet.efficient_net_b6,
        11: efficientnet.efficient_net_b7,
        12: resnext.ResNeXt50,
        13: resnext.ResNeXt101,
        14: inception_v4.InceptionV4,
        15: inception_resnet_v1.InceptionResNetV1,
        16: inception_resnet_v2.InceptionResNetV2,
        17: se_resnet.se_resnet_50,
        18: se_resnet.se_resnet_101,
        19: se_resnet.se_resnet_152,
        20: squeezenet.SqueezeNet,
        21: densenet.densenet_121,
        22: densenet.densenet_169,
        23: densenet.densenet_201,
        24: densenet.densenet_264,
        25: shufflenet_v2.shufflenet_0_5x,
        26: shufflenet_v2.shufflenet_1_0x,
        27: shufflenet_v2.shufflenet_1_5x,
        28: shufflenet_v2.shufflenet_2_0x,
        29: resnet.resnet_18,
        30: resnet.resnet_34,
        31: resnet.resnet_50,
        32: resnet.resnet_101,
        33: resnet.resnet_152,
        34: vgg16.VGG16,
        35: vgg16_mini.VGG16,
        36: VGG16_self.VGG16,
        10086: diy_resnet.resnet_50,
    }
    if model_index not in model_constructors:
        raise ValueError("The model_index does not exist.")
    return model_constructors[model_index]()
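

# Usage sketch (not in the original file): ``model_index`` must be defined at
# module scope before calling get_model(); setting it inline here is purely
# illustrative.
#
#     model_index = 31      # e.g. select resnet.resnet_50()
#     model = get_model()   # builds only the selected network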