示例#1
0
def _build_domain_branch(data_dir, hparams):
    """Build one input pipeline + model branch for a single domain.

    Loads ``<data_dir>/positive`` and ``<data_dir>/negative`` with
    ``get_dataset``, concatenates them along axis 0 (positive rows first) and
    runs ``model`` on the result.

    Args:
      data_dir: Root directory of the domain; must contain ``positive`` and
        ``negative`` sub-directories.
      hparams: Hyper-parameters; ``dropout_keep_prob`` and ``scope`` are read.

    Returns:
      Tuple ``(net, end_points, images, class_labels, theta_labels)``.
    """
    images_pos, class_labels_pos, theta_labels_pos = get_dataset(
        os.path.join(data_dir, 'positive'), args.num_readers,
        args.num_preprocessing_threads, hparams)
    images_neg, class_labels_neg, theta_labels_neg = get_dataset(
        os.path.join(data_dir, 'negative'), args.num_readers,
        args.num_preprocessing_threads, hparams)
    images = tf.concat([images_pos, images_neg], axis=0)
    class_labels = tf.concat([class_labels_pos, class_labels_neg], axis=0)
    theta_labels = tf.concat([theta_labels_pos, theta_labels_neg], axis=0)
    with slim.arg_scope(model_arg_scope()):
        net, end_points = model(
            inputs=images,
            num_classes=num_classes,
            is_training=True,
            dropout_keep_prob=hparams.dropout_keep_prob,
            # AUTO_REUSE under the same scope: the second branch reuses the
            # variables created by the first, so both domains share weights.
            reuse=tf.AUTO_REUSE,
            scope=hparams.scope,
            adapt_scope='adapt_layer',
            adapt_dims=128)
    return net, end_points, images, class_labels, theta_labels


def main():
    """Train the domain-adaptation model on target and source data.

    Configuration comes from the module-level ``args`` namespace and from
    ``create_domain_adapt_hparams()``. Two branches of the same network
    (shared variables) are built: one from ``args.target_dir`` and one from
    ``args.source_dir``. Their outputs are concatenated (target first) and
    passed to ``create_loss``. The model is optimized with SGD via
    ``slim.learning.train``.

    Side effects: creates ``args.train_log_dir`` if needed, writes the
    hyper-parameters there as ``hparams.json``, and trains with that
    directory as the slim ``logdir``.

    Raises:
      ValueError: if the warm-start runs and ``args.checkpoint_dir`` holds no
        checkpoint.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    hparams = create_domain_adapt_hparams()

    if not tf.gfile.Exists(args.train_log_dir):
        tf.gfile.MakeDirs(args.train_log_dir)
    hparams_filename = os.path.join(args.train_log_dir, 'hparams.json')
    with tf.gfile.GFile(hparams_filename, 'w') as f:
        f.write(hparams.to_json())

    with tf.Graph().as_default():
        # NOTE(review): the first positional parameter of
        # tf.train.replica_device_setter is `ps_tasks` (number of parameter
        # servers), not a task index. With task_id == 0 it is a no-op; other
        # values pin variables to /job:ps. Confirm this is intended.
        with tf.device(tf.train.replica_device_setter(args.task_id)):
            global_step = tf.train.get_or_create_global_step()

            # The target branch is built first; every concatenation below
            # keeps target rows before source rows.
            (net_t, end_points_t, images_t,
             class_labels_t, theta_labels_t) = _build_domain_branch(
                 args.target_dir, hparams)
            (net_s, end_points_s, images_s,
             class_labels_s, theta_labels_s) = _build_domain_branch(
                 args.source_dir, hparams)

            net = tf.concat([net_t, net_s], axis=0)
            images = tf.concat([images_t, images_s], axis=0)
            class_labels = tf.concat([class_labels_t, class_labels_s], axis=0)
            theta_labels = tf.concat([theta_labels_t, theta_labels_s], axis=0)

            # Both branches use the same scope, so their end-point keys
            # presumably coincide. Each branch's adapt layer is therefore
            # exposed under its own key before merging. For every other key
            # the source branch's entry wins.
            adapt_key = hparams.scope + '/adapt_layer'
            target_key = hparams.scope + '/target_adapt_layer'
            source_key = hparams.scope + '/source_adapt_layer'
            end_points_t[target_key] = end_points_t[adapt_key]
            end_points_s[source_key] = end_points_s[adapt_key]
            end_points = {}
            end_points.update(end_points_t)
            end_points.update(end_points_s)

            loss, accuracy = create_loss(
                net,
                end_points,
                class_labels,
                theta_labels,
                scope=hparams.scope,
                source_adapt_scope='source_adapt_layer',
                target_adapt_scope='target_adapt_layer')

            learning_rate = hparams.learning_rate
            if hparams.lr_decay_step:
                learning_rate = tf.train.exponential_decay(
                    hparams.learning_rate,
                    global_step,
                    decay_steps=hparams.lr_decay_step,
                    decay_rate=hparams.lr_decay_rate,
                    staircase=True)
            tf.summary.scalar('Learning_rate', learning_rate)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            train_op = slim.learning.create_train_op(loss, optimizer)
            add_summary(images,
                        end_points,
                        loss,
                        accuracy,
                        scope='domain_adapt')
            summary_op = tf.summary.merge_all()

            # Warm-start map; `checkpoint_exclude_scopes` presumably lists
            # scopes left out of the restore (semantics of restore_map are
            # defined elsewhere — confirm).
            variable_map = restore_map(
                from_adapt_checkpoint=args.from_adapt_checkpoint,
                scope=hparams.scope,
                model_name='source_only',
                checkpoint_exclude_scopes=['adapt_layer', 'fc8'])
            init_saver = tf.train.Saver(variable_map)

            def init_fn(sess):
                """Restore the warm-start variables from the newest checkpoint."""
                checkpoint_path = tf.train.latest_checkpoint(
                    args.checkpoint_dir)
                if checkpoint_path is None:
                    # Fail with a clear message instead of handing None to
                    # Saver.restore.
                    raise ValueError(
                        'No checkpoint found in %s' % args.checkpoint_dir)
                init_saver.restore(sess, checkpoint_path)
                tf.logging.info('Successfully load pretrained checkpoint.')

            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)
            session_config.gpu_options.allow_growth = True
            # `save_interval_secs` is a period in seconds. It controls how
            # often slim.learning.train checkpoints, and it must not be fed
            # into the Saver's `keep_checkpoint_every_n_hours`, which is in
            # hours.
            saver = tf.train.Saver(max_to_keep=200)

            slim.learning.train(
                train_op,
                logdir=args.train_log_dir,
                master=args.master,
                global_step=global_step,
                session_config=session_config,
                init_fn=init_fn,
                summary_op=summary_op,
                number_of_steps=args.num_steps,
                startup_delay_steps=15,
                # NOTE(review): slim.learning.train only accepts a
                # time-based summary period, so this step count is used as
                # seconds. Confirm the intended unit.
                save_summaries_secs=args.save_summaries_steps,
                save_interval_secs=args.save_interval_secs,
                saver=saver)
示例#2
0
def train(run_dir,
          master,
          task_id,
          num_readers,
          from_graspnet_checkpoint,
          dataset_dir,
          checkpoint_dir,
          save_summaries_steps,
          save_interval_secs,
          num_preprocessing_threads,
          num_steps,
          hparams,
          scope='graspnet'):
    """Train the AlexNet-v2 grasp classifier on a single dataset.

    Args:
      run_dir: Output directory for ``hparams.json``, checkpoints and
        summaries. It is created if missing.
      master: TensorFlow master address, forwarded to slim.learning.train.
      task_id: Forwarded to tf.train.replica_device_setter (see NOTE below).
      num_readers: Forwarded to get_dataset.
      from_graspnet_checkpoint: Forwarded to restore_map.
      dataset_dir: Directory read by get_dataset.
      checkpoint_dir: Warm-start checkpoint, either a directory (its latest
        checkpoint is used) or a checkpoint path. The warm-start is
        currently disabled (init_fn is not passed to training).
      save_summaries_steps: Summary period, forwarded as
        save_summaries_secs (see NOTE below).
      save_interval_secs: Checkpoint period in seconds.
      num_preprocessing_threads: Forwarded to get_dataset.
      num_steps: Total number of training steps.
      hparams: Hyper-parameters. It must provide to_json(), learning_rate,
        lr_decay_step, lr_decay_rate and model_name.
      scope: Variable scope of the network.
    """
    if not tf.gfile.Exists(run_dir):
        # The tf.gfile API is `MakeDirs`; `Makedirs` raises AttributeError.
        tf.gfile.MakeDirs(run_dir)
    hparams_filename = os.path.join(run_dir, 'hparams.json')
    with tf.gfile.GFile(hparams_filename, 'w') as f:
        f.write(hparams.to_json())

    with tf.Graph().as_default():
        # NOTE(review): the first positional parameter of
        # tf.train.replica_device_setter is `ps_tasks` (number of parameter
        # servers), not a task index. With task_id == 0 it is a no-op; other
        # values pin variables to /job:ps. Confirm this is intended.
        with tf.device(tf.train.replica_device_setter(task_id)):
            global_step = slim.get_or_create_global_step()
            images, class_labels, theta_labels = get_dataset(
                dataset_dir, num_readers, num_preprocessing_threads, hparams)
            with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
                net, end_points = alexnet.alexnet_v2(inputs=images,
                                                     num_classes=num_classes,
                                                     is_training=True,
                                                     dropout_keep_prob=0.7,
                                                     scope=scope)
            loss, accuracy = create_loss(net, class_labels, theta_labels)

            learning_rate = hparams.learning_rate
            if hparams.lr_decay_step:
                learning_rate = tf.train.exponential_decay(
                    hparams.learning_rate,
                    global_step,
                    decay_steps=hparams.lr_decay_step,
                    decay_rate=hparams.lr_decay_rate,
                    staircase=True)
            tf.summary.scalar('Learning_rate', learning_rate)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            train_op = slim.learning.create_train_op(loss, optimizer)
            add_summary(images, end_points, loss, accuracy, scope=scope)
            summary_op = tf.summary.merge_all()

            variable_map = restore_map(
                from_graspnet_checkpoint=from_graspnet_checkpoint,
                scope=scope,
                model_name=hparams.model_name,
                checkpoint_exclude_scope='fc8')
            init_saver = tf.train.Saver(variable_map)

            def init_fn(sess):
                """Restore warm-start variables from `checkpoint_dir`."""
                checkpoint_path = checkpoint_dir
                # Saver.restore needs a checkpoint prefix, not a directory.
                if tf.gfile.IsDirectory(checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        checkpoint_dir)
                if checkpoint_path is None:
                    raise ValueError(
                        'No checkpoint found in %s' % checkpoint_dir)
                init_saver.restore(sess, checkpoint_path)
                tf.logging.info('Successfully load pretrained checkpoint.')

            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)
            session_config.gpu_options.allow_growth = True
            # `save_interval_secs` is in seconds. It sets the checkpoint
            # period of slim.learning.train and must not be fed into the
            # Saver's `keep_checkpoint_every_n_hours`, which is in hours.
            saver = tf.train.Saver(max_to_keep=100)

            # NOTE(review): init_fn is intentionally not passed (as in the
            # original), so the pretrained checkpoint is never restored and
            # training starts from fresh initialization. Pass init_fn=init_fn
            # to enable the warm-start.
            slim.learning.train(
                train_op,
                logdir=run_dir,
                master=master,
                global_step=global_step,
                session_config=session_config,
                summary_op=summary_op,
                number_of_steps=num_steps,
                startup_delay_steps=15,
                # NOTE(review): slim.learning.train only accepts a
                # time-based summary period, so this step count is used as
                # seconds. Confirm the intended unit.
                save_summaries_secs=save_summaries_steps,
                save_interval_secs=save_interval_secs,
                saver=saver)