def main():
  tf.logging.set_verbosity(tf.logging.INFO)
  hparams = create_domain_adapt_hparams()

  # Create the log directory and dump the hyperparameters alongside it.
  for path in [args.train_log_dir]:
    if not tf.gfile.Exists(path):
      tf.gfile.MakeDirs(path)
  hparams_filename = os.path.join(args.train_log_dir, 'hparams.json')
  with tf.gfile.FastGFile(hparams_filename, 'w') as f:
    f.write(hparams.to_json())

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(args.task_id)):
      global_step = tf.train.get_or_create_global_step()

      # Target-domain batch: positive and negative examples concatenated.
      images_p_t, class_labels_p_t, theta_labels_p_t = get_dataset(
          os.path.join(args.target_dir, 'positive'),
          args.num_readers,
          args.num_preprocessing_threads,
          hparams)
      images_n_t, class_labels_n_t, theta_labels_n_t = get_dataset(
          os.path.join(args.target_dir, 'negative'),
          args.num_readers,
          args.num_preprocessing_threads,
          hparams)
      images_t = tf.concat([images_p_t, images_n_t], axis=0)
      class_labels_t = tf.concat([class_labels_p_t, class_labels_n_t], axis=0)
      theta_labels_t = tf.concat([theta_labels_p_t, theta_labels_n_t], axis=0)

      with slim.arg_scope(model_arg_scope()):
        net_t, end_points_t = model(
            inputs=images_t,
            num_classes=num_classes,
            is_training=True,
            dropout_keep_prob=hparams.dropout_keep_prob,
            reuse=tf.AUTO_REUSE,
            scope=hparams.scope,
            adapt_scope='adapt_layer',
            adapt_dims=128)

      # Source-domain batch: positive and negative examples concatenated.
      images_p_s, class_labels_p_s, theta_labels_p_s = get_dataset(
          os.path.join(args.source_dir, 'positive'),
          args.num_readers,
          args.num_preprocessing_threads,
          hparams)
      images_n_s, class_labels_n_s, theta_labels_n_s = get_dataset(
          os.path.join(args.source_dir, 'negative'),
          args.num_readers,
          args.num_preprocessing_threads,
          hparams)
      images_s = tf.concat([images_p_s, images_n_s], axis=0)
      class_labels_s = tf.concat([class_labels_p_s, class_labels_n_s], axis=0)
      theta_labels_s = tf.concat([theta_labels_p_s, theta_labels_n_s], axis=0)

      # The same model (shared weights via tf.AUTO_REUSE) processes the
      # source batch.
      with slim.arg_scope(model_arg_scope()):
        net_s, end_points_s = model(
            inputs=images_s,
            num_classes=num_classes,
            is_training=True,
            dropout_keep_prob=hparams.dropout_keep_prob,
            reuse=tf.AUTO_REUSE,
            scope=hparams.scope,
            adapt_scope='adapt_layer',
            adapt_dims=128)

      # Merge target and source outputs so one loss covers both domains.
      net = tf.concat([net_t, net_s], axis=0)
      images = tf.concat([images_t, images_s], axis=0)
      class_labels = tf.concat([class_labels_t, class_labels_s], axis=0)
      theta_labels = tf.concat([theta_labels_t, theta_labels_s], axis=0)

      # Re-key the adaptation layers so each domain's activations stay
      # distinguishable after the two end-point dicts are merged.
      end_points = {}
      end_points_t[hparams.scope + '/target_adapt_layer'] = end_points_t[
          hparams.scope + '/adapt_layer']
      end_points_s[hparams.scope + '/source_adapt_layer'] = end_points_s[
          hparams.scope + '/adapt_layer']
      end_points.update(end_points_t)
      end_points.update(end_points_s)

      loss, accuracy = create_loss(
          net, end_points, class_labels, theta_labels,
          scope=hparams.scope,
          source_adapt_scope='source_adapt_layer',
          target_adapt_scope='target_adapt_layer')

      # SGD with optional staircase exponential learning-rate decay.
      learning_rate = hparams.learning_rate
      if hparams.lr_decay_step:
        learning_rate = tf.train.exponential_decay(
            hparams.learning_rate,
            tf.train.get_or_create_global_step(),
            decay_steps=hparams.lr_decay_step,
            decay_rate=hparams.lr_decay_rate,
            staircase=True)
      tf.summary.scalar('Learning_rate', learning_rate)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate)
      train_op = slim.learning.create_train_op(loss, optimizer)

      add_summary(images, end_points, loss, accuracy, scope='domain_adapt')
      summary_op = tf.summary.merge_all()

      # Restore the pretrained source-only weights, excluding the adaptation
      # layer and the final classifier.
      variable_map = restore_map(
          from_adapt_checkpoint=args.from_adapt_checkpoint,
          scope=hparams.scope,
          model_name='source_only',
          checkpoint_exclude_scopes=['adapt_layer', 'fc8'])
      init_saver = tf.train.Saver(variable_map)

      def initializer_fn(sess):
        init_saver.restore(sess,
                           tf.train.latest_checkpoint(args.checkpoint_dir))
        tf.logging.info('Successfully loaded pretrained checkpoint.')

      init_fn = initializer_fn

      session_config = tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)
      session_config.gpu_options.allow_growth = True
      saver = tf.train.Saver(
          keep_checkpoint_every_n_hours=args.save_interval_secs,
          max_to_keep=200)

      slim.learning.train(
          train_op,
          logdir=args.train_log_dir,
          master=args.master,
          global_step=global_step,
          session_config=session_config,
          init_fn=init_fn,
          summary_op=summary_op,
          number_of_steps=args.num_steps,
          startup_delay_steps=15,
          save_summaries_secs=args.save_summaries_steps,
          saver=saver)
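# The `args.*` flags referenced in main() are not defined in this snippet.
# Below is a minimal sketch of an argparse entry point that would supply
# them; the flag names mirror the attributes used above, but the defaults
# and help strings are assumptions rather than the original configuration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_log_dir', default='logs/domain_adapt',
                    help='Directory for checkpoints and summaries.')
parser.add_argument('--master', default='',
                    help='Address of the TensorFlow master to use.')
parser.add_argument('--task_id', type=int, default=0,
                    help='Task id of the replica running the training.')
parser.add_argument('--num_readers', type=int, default=4,
                    help='Number of parallel dataset readers.')
parser.add_argument('--num_preprocessing_threads', type=int, default=4,
                    help='Number of preprocessing threads.')
parser.add_argument('--source_dir', default='',
                    help='Directory with source-domain data.')
parser.add_argument('--target_dir', default='',
                    help='Directory with target-domain data.')
parser.add_argument('--checkpoint_dir', default='',
                    help='Directory holding the pretrained checkpoint.')
parser.add_argument('--from_adapt_checkpoint', action='store_true',
                    help='Restore from an adaptation checkpoint.')
parser.add_argument('--num_steps', type=int, default=100000,
                    help='Total number of training steps.')
parser.add_argument('--save_summaries_steps', type=int, default=120,
                    help='Interval for writing summaries.')
parser.add_argument('--save_interval_secs', type=int, default=2,
                    help='Passed to keep_checkpoint_every_n_hours on the Saver.')
args = parser.parse_args()

if __name__ == '__main__':
  main()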
def train(run_dir,
          master,
          task_id,
          num_readers,
          from_graspnet_checkpoint,
          dataset_dir,
          checkpoint_dir,
          save_summaries_steps,
          save_interval_secs,
          num_preprocessing_threads,
          num_steps,
          hparams,
          scope='graspnet'):
  # Create the run directory and dump the hyperparameters alongside it.
  for path in [run_dir]:
    if not tf.gfile.Exists(path):
      tf.gfile.MakeDirs(path)
  hparams_filename = os.path.join(run_dir, 'hparams.json')
  with tf.gfile.FastGFile(hparams_filename, 'w') as f:
    f.write(hparams.to_json())

  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(task_id)):
      global_step = slim.get_or_create_global_step()

      images, class_labels, theta_labels = get_dataset(
          dataset_dir, num_readers, num_preprocessing_threads, hparams)

      # A VGG-16 backbone can be swapped in for AlexNet:
      # with slim.arg_scope(vgg.vgg_arg_scope()):
      #   net, end_points = vgg.vgg_16(inputs=images,
      #                                num_classes=num_classes,
      #                                is_training=True,
      #                                dropout_keep_prob=0.7,
      #                                scope=scope)
      with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
        net, end_points = alexnet.alexnet_v2(inputs=images,
                                             num_classes=num_classes,
                                             is_training=True,
                                             dropout_keep_prob=0.7,
                                             scope=scope)

      loss, accuracy = create_loss(net, class_labels, theta_labels)

      # SGD with optional staircase exponential learning-rate decay.
      learning_rate = hparams.learning_rate
      if hparams.lr_decay_step:
        learning_rate = tf.train.exponential_decay(
            hparams.learning_rate,
            slim.get_or_create_global_step(),
            decay_steps=hparams.lr_decay_step,
            decay_rate=hparams.lr_decay_rate,
            staircase=True)
      tf.summary.scalar('Learning_rate', learning_rate)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate)
      train_op = slim.learning.create_train_op(loss, optimizer)

      add_summary(images, end_points, loss, accuracy, scope=scope)
      summary_op = tf.summary.merge_all()

      # Restore pretrained weights, excluding the final classifier layer.
      variable_map = restore_map(
          from_graspnet_checkpoint=from_graspnet_checkpoint,
          scope=scope,
          model_name=hparams.model_name,
          checkpoint_exclude_scope='fc8')
      init_saver = tf.train.Saver(variable_map)

      def initializer_fn(sess):
        init_saver.restore(sess, checkpoint_dir)
        tf.logging.info('Successfully loaded pretrained checkpoint.')

      init_fn = initializer_fn

      session_config = tf.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=False)
      session_config.gpu_options.allow_growth = True
      saver = tf.train.Saver(
          keep_checkpoint_every_n_hours=save_interval_secs,
          max_to_keep=100)

      slim.learning.train(
          train_op,
          logdir=run_dir,
          master=master,
          global_step=global_step,
          session_config=session_config,
          # init_fn=init_fn,
          summary_op=summary_op,
          number_of_steps=num_steps,
          startup_delay_steps=15,
          save_summaries_secs=save_summaries_steps,
          saver=saver)
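# For reference, the hparams objects consumed by both training loops only need
# the attributes read above (learning_rate, lr_decay_step, lr_decay_rate,
# dropout_keep_prob, scope, model_name) plus a to_json() method. A minimal
# sketch of create_domain_adapt_hparams() using tf.contrib.training.HParams is
# given below; the default values are assumptions, not the original settings.
def create_domain_adapt_hparams():
  return tf.contrib.training.HParams(
      learning_rate=0.001,        # initial SGD learning rate
      lr_decay_step=10000,        # decay every N steps (0 disables decay)
      lr_decay_rate=0.9,          # multiplicative decay factor
      dropout_keep_prob=0.7,      # keep probability for dropout layers
      scope='graspnet',           # variable scope of the backbone network
      model_name='source_only')   # which restore_map variant to use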