示例#1
0
def train(exp_id, files, args):
    """Train, or only evaluate, a single-stream network on one modality.

    Builds a feedable-iterator input pipeline over GZIP TFRecords, the
    network returned by ``resnet_v1.resnet_one_stream_main`` (scope
    ``'resnet_v1_50'``), a softmax cross-entropy loss and an Adam train op.
    After every epoch the validation split (if any) is evaluated and the
    model is checkpointed whenever the validation accuracy is at least as
    good as the best one seen so far. At the end the best checkpoint is
    restored (or, without a validation split, the final model is saved) and
    the test split is evaluated.

    Args:
        exp_id: Experiment identifier; used as a sub-directory name under
            ``./log/<dset>/`` and ``./checkpoint/<dset>/``.
        files: Mapping read for the keys ``'data'`` (passed to
            ``utils.get_tfrecords``) and ``'imagenet_checkpoint_file'``
            (passed to ``restorers.restore_weights_s1``).
        args: Options object. Attributes read here: ``eval_mode``, ``dset``,
            ``modality``, ``batch_sz``, ``gpu0``, ``bn_decay``,
            ``learning_rate``, ``just_eval``, ``ckpt``, ``n_epochs``.

    Returns:
        None. Results are written to the log file, TensorBoard summaries and
        checkpoints.

    Raises:
        ValueError: If ``args.dset`` is not a dataset whose label layout is
            handled here (``'uwa3dii'``, ``'nwucla'`` or a name containing
            ``'ntu'``).
    """
    # Local import: the file's import header is not guaranteed to provide it.
    from contextlib import closing

    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(lambda x: parsers._parse_fun_one_mod(
            x, is_training=True, modality=args.modality))
        # The shuffle seed is fed at iterator initialisation time; the
        # training loop below feeds the epoch index.
        seed = tf.placeholder(tf.int64, shape=())  # =epoch
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(lambda x: parsers._parse_fun_one_mod(
                x, is_training=False, modality=args.modality))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(lambda x: parsers._parse_fun_one_mod(
            x, is_training=False, modality=args.modality))
        dset_test = dset_test.batch(args.batch_sz)

        # Feedable iterator: the split to read from is selected at run time
        # by feeding the string handle of the matching concrete iterator.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_stacked = next_element[0]  # [batch_sz, time_bottleneck, h,w,c]
        if args.dset == 'uwa3dii':  # because tfrecords labels are [1,30]
            labels = next_element[1] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[1]
        else:
            # Previously fell through and crashed later with a NameError.
            raise ValueError('Unsupported dataset: %s' % args.dset)
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)

        stack_shape = tf.shape(images_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        batch_images = tf.reshape(
            images_stacked, [stack_shape[0] * stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(
                resnet_v1.resnet_arg_scope(batch_norm_decay=args.bn_decay)):
            is_training = tf.placeholder(tf.bool, name="is_training")
            nr_frames = parsers.time_bottleneck
            scope = 'resnet_v1_50'

            net_out, net_endpoints = resnet_v1.resnet_one_stream_main(
                batch_images,
                nr_frames,
                num_classes=n_classes,
                scope=scope,
                gpu_id=args.gpu0,
                is_training=is_training)

            # predictions for each video are the avg of frames' predictions
            # TRAIN ###############################
            net_train = tf.reshape(net_out, [-1, nr_frames, n_classes])
            net_train = tf.reduce_mean(net_train, axis=1)
            # TEST ###############################
            net_test = tf.reshape(net_out, [-1, nr_frames, n_classes])
            net_test = tf.reduce_mean(net_test, axis=1)

            # loss ##########################################################
            loss = slim.losses.softmax_cross_entropy(net_train, labels)
            # optimizers ######################################################
            optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            minimizing = slim.learning.create_train_op(loss, optimizer)

            acc_train = utils.accuracy(net_train, labels)
            n_correct = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_test, labels), tf.float32))

    summ_loss = tf.summary.scalar('loss', loss)
    summ_acc_train = tf.summary.scalar('acc_train', acc_train)
    summ_train = tf.summary.merge([summ_acc_train, summ_loss])
    # Evaluation accuracy is computed in Python and fed through this
    # placeholder so it can be written as a scalar summary.
    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summ_acc_test = tf.summary.scalar('acc_test', accuracy_value_)
    summ_acc_val = tf.summary.scalar('acc_val', accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        run_log_dir = os.path.join(log_path, args.dset, exp_id)
        ckpt_prefix = os.path.join(ckpt_path, args.dset, exp_id,
                                   'test/model.ckpt')

        summary_writer = tf.summary.FileWriter(run_log_dir, sess.graph)
        # Both the summary writer and the text log are closed on every exit
        # path, including exceptions raised while training or evaluating.
        with closing(summary_writer), \
                open(os.path.join(run_log_dir, 'log.txt'), 'a') as f_log:
            utils.double_log(
                f_log, '\n###############################################\n' +
                exp_id + '\n#####################################\n')
            f_log.write(' '.join(sys.argv[:]) + '\n')
            f_log.flush()

            sess.run(tf.global_variables_initializer())
            if args.just_eval:
                restorers.restore_weights_s1_continue(sess, args.ckpt,
                                                      args.modality)
            else:
                restorers.restore_weights_s1(
                    sess, files['imagenet_checkpoint_file'])

            def val_test(value_step, mode='val'):
                """Evaluate one split and log its accuracy.

                Args:
                    value_step: Global step under which the accuracy summary
                        is recorded.
                    mode: ``'val'`` or ``'test'``.

                Returns:
                    The accuracy on the split, or -1 when ``mode == 'val'``
                    and there is no validation split.

                Raises:
                    ValueError: If ``mode`` is neither ``'val'`` nor
                        ``'test'``.
                """
                if mode == 'val':
                    if not val_filenames:
                        return -1
                    utils.double_log(f_log, "eval validation set \n")
                    sess.run(val_iterator.initializer)
                    step_handle = val_handle
                    step_samples = len(val_filenames)
                    step_summ = summ_acc_val
                elif mode == 'test':
                    utils.double_log(f_log, "eval test set \n")
                    sess.run(test_iterator.initializer)
                    step_handle = test_handle
                    step_samples = len(test_filenames)
                    step_summ = summ_acc_test
                else:
                    raise ValueError(
                        "mode must be 'val' or 'test', got %r" % (mode,))

                accum_correct = 0
                try:
                    while True:
                        accum_correct += sess.run(n_correct,
                                                  feed_dict={
                                                      handle: step_handle,
                                                      is_training: False
                                                  })
                except tf.errors.OutOfRangeError:
                    pass  # iterator exhausted: the whole split was seen

                # NOTE(review): accuracy is normalised by the number of
                # files, which assumes one sample per TFRecord file --
                # TODO confirm.
                step_acc = accum_correct / step_samples
                summary_acc = sess.run(step_summ,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log, 'Accuracy = %s \n' % str(step_acc))
                return step_acc

            if args.just_eval:
                val_test(-1, mode='test')
                return

            # Baseline accuracies before any training step.
            val_test(-1, mode='val')
            val_test(-1, mode='test')
            n_step = 0
            best_acc = best_epoch = best_step = -1
            for epoch in range(args.n_epochs):
                utils.double_log(f_log, 'epoch %s \n' % str(epoch))
                sess.run(train_iterator.initializer, feed_dict={seed: epoch})
                try:
                    while True:
                        print(n_step)
                        if n_step % 100 == 0:  # get summaries
                            _, summary = sess.run([minimizing, summ_train],
                                                  feed_dict={
                                                      handle: train_handle,
                                                      is_training: True
                                                  })
                            summary_writer.add_summary(summary, n_step)
                        else:
                            sess.run(minimizing,
                                     feed_dict={
                                         handle: train_handle,
                                         is_training: True
                                     })
                        n_step += 1
                except tf.errors.OutOfRangeError:
                    pass  # end of the epoch
                acc_validation = val_test(n_step, mode='val')

                # Without a validation split there is no model selection.
                if not val_filenames:
                    continue
                if acc_validation >= best_acc:
                    best_acc = acc_validation
                    best_epoch = epoch
                    best_step = n_step
                    test_saver.save(sess, ckpt_prefix, global_step=n_step)

            utils.double_log(f_log, "Optimization Finished!\n")
            if val_filenames:
                utils.double_log(
                    f_log,
                    "Best Validation Accuracy: %f at epoch %d %d\n" %
                    (best_acc, best_epoch, best_step))
                # Reload the best validation checkpoint before testing.
                variables_to_restore = slim.get_variables_to_restore()
                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, ckpt_prefix + '-' + str(best_step))
            else:
                test_saver.save(sess, ckpt_prefix, global_step=n_step)

            val_test(n_step + 1, mode='test')
示例#2
0
def train(exp_id, files, args):
    """Train, or only evaluate, a two-stream (depth + RGB) network.

    Builds a feedable-iterator input pipeline over GZIP TFRecords and two
    networks via ``resnet_v1.resnet_one_stream_main`` (scopes
    ``'resnet_v1_50_depth'`` and ``'resnet_v1_50_rgb'``). The per-video
    predictions of both streams are averaged, and the cross-entropy of that
    combined prediction is minimised with Adam. After every epoch the
    validation split (if any) is evaluated and the model is checkpointed
    whenever the combined validation accuracy is at least as good as the
    best one so far. Finally the best checkpoint is restored (or, without a
    validation split, the final model is saved) and the test split is
    evaluated.

    Args:
        exp_id: Experiment identifier; used as a sub-directory name under
            ``./log/<dset>/`` and ``./checkpoint/<dset>/``.
        files: Mapping read for the key ``'data'`` (passed to
            ``utils.get_tfrecords``).
        args: Options object. Attributes read here: ``eval_mode``, ``dset``,
            ``batch_sz``, ``noise``, ``gpu0``, ``learning_rate``, ``ckpt``,
            ``just_eval``, ``n_epochs``.

    Returns:
        None. Results are written to the log file, TensorBoard summaries and
        checkpoints.

    Raises:
        ValueError: If ``args.dset`` is not a dataset whose label layout is
            handled here (``'uwa3dii'``, ``'nwucla'`` or a name containing
            ``'ntu'``).
    """
    # Local import: the file's import header is not guaranteed to provide it.
    from contextlib import closing

    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=True))
        # The shuffle seed is fed at iterator initialisation time; the
        # training loop below feeds the epoch index.
        seed = tf.placeholder(tf.int64, shape=())
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(
                lambda x: parsers._parse_fun_2stream(x, is_training=False))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=False))
        dset_test = dset_test.batch(args.batch_sz)

        # Feedable iterator: the split to read from is selected at run time
        # by feeding the string handle of the matching concrete iterator.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_depth_stacked = next_element[0]  # [batch, pooled_frames, h,w,c]
        images_rgb_stacked = next_element[1]

        if args.noise > 0:
            # Multiplicative Gaussian noise (mean 1, stddev args.noise) on
            # the depth input, clipped to the uint8 range and cast back to
            # float.
            images_depth_stacked = images_depth_stacked * \
                tf.random_normal(shape=tf.shape(
                    images_depth_stacked), mean=1, stddev=args.noise, dtype=tf.float32)
            images_depth_stacked = tf.saturate_cast(images_depth_stacked,
                                                    dtype=tf.uint8)
            images_depth_stacked = tf.to_float(images_depth_stacked)
        elif args.noise < 0:
            # A negative noise level replaces the depth input with zeros.
            images_depth_stacked = tf.zeros(tf.shape(images_depth_stacked),
                                            tf.float32)

        if args.dset == 'uwa3dii':  # because tfrecords labels are [1,30]
            labels = next_element[2] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[2]
        else:
            # Previously fell through and crashed later with a NameError.
            raise ValueError('Unsupported dataset: %s' % args.dset)
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)

        rgb_stack_shape = tf.shape(images_rgb_stacked)
        depth_stack_shape = tf.shape(images_depth_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        images_rgb = tf.reshape(
            images_rgb_stacked,
            [rgb_stack_shape[0] * rgb_stack_shape[1], 224, 224, 3])
        images_depth = tf.reshape(
            images_depth_stacked,
            [depth_stack_shape[0] * depth_stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            is_training = tf.placeholder(tf.bool, [])
            nr_frames = parsers.time_bottleneck

            net_depth_out, endpoints_depth = resnet_v1.resnet_one_stream_main(
                images_depth,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_depth',
                gpu_id=args.gpu0,
                is_training=is_training,
                bottleneck=True)
            # NOTE(review): the RGB stream's gpu_id is hard-coded to
            # '/gpu:1' while the depth stream uses args.gpu0 -- confirm this
            # is intended.
            net_rgb_out, endpoints_rgb = resnet_v1.resnet_one_stream_main(
                images_rgb,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_rgb',
                gpu_id='/gpu:1',
                is_training=is_training,
                bottleneck=False)

            # predictions for each video are the avg of frames' predictions
            # TRAIN ###############################
            net_depth_train = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_train = tf.reduce_mean(net_depth_train, axis=1)
            net_rgb_train = tf.reshape(
                net_rgb_out, [-1, parsers.time_bottleneck, n_classes])
            net_rgb_train = tf.reduce_mean(net_rgb_train, axis=1)
            net_combined_train = tf.add(net_depth_train, net_rgb_train) / 2.0

            # TEST ###############################
            net_rgb_test = tf.reshape(net_rgb_out,
                                      [-1, parsers.time_bottleneck, n_classes])
            net_rgb_test = tf.reduce_mean(net_rgb_test, axis=1)
            net_depth_test = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_test = tf.reduce_mean(net_depth_test, axis=1)
            net_combined_test = tf.add(net_rgb_test, net_depth_test) / 2.0

            # losses ##########################################################
            loss_combined = slim.losses.softmax_cross_entropy(
                net_combined_train, labels)
            loss_depth = slim.losses.softmax_cross_entropy(
                net_depth_train, labels)
            loss_rgb = slim.losses.softmax_cross_entropy(net_rgb_train, labels)

            # Only the combined loss is passed to the train op; the
            # per-stream losses are also written as summaries below.
            optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            minimizing = slim.learning.create_train_op(loss_combined,
                                                       optimizer)

            acc_depth_train = utils.accuracy(net_depth_train, labels)
            acc_rgb_train = utils.accuracy(net_rgb_train, labels)
            acc_combined_train = utils.accuracy(net_combined_train, labels)

            n_correct_depth = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_depth_test, labels),
                        tf.float32))
            n_correct_rgb = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_rgb_test, labels), tf.float32))
            n_correct_combined = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_combined_test, labels),
                        tf.float32))

    summ_loss_combined = tf.summary.scalar('loss_combined', loss_combined)
    summ_loss_depth = tf.summary.scalar('loss_depth', loss_depth)
    summ_loss_rgb = tf.summary.scalar('loss_rgb', loss_rgb)
    summ_acc_train_rgb = tf.summary.scalar('acc_train_rgb', acc_rgb_train)
    summ_acc_train_depth = tf.summary.scalar('acc_train_depth',
                                             acc_depth_train)
    summ_acc_train_combined = tf.summary.scalar('acc_train_combined',
                                                acc_combined_train)
    summary_train = tf.summary.merge([
        summ_acc_train_rgb, summ_acc_train_depth, summ_acc_train_combined,
        summ_loss_depth, summ_loss_rgb, summ_loss_combined
    ])

    # Evaluation accuracies are computed in Python and fed through this
    # placeholder so they can be written as scalar summaries.
    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summary_acc_val_rgb = tf.summary.scalar('acc_val_rgb', accuracy_value_)
    summary_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    summary_acc_val_combined = tf.summary.scalar('acc_val_combined',
                                                 accuracy_value_)
    summary_acc_test_rgb = tf.summary.scalar('acc_test_rgb', accuracy_value_)
    summary_acc_test_depth = tf.summary.scalar('acc_test_depth',
                                               accuracy_value_)
    summary_acc_test_combined = tf.summary.scalar('acc_test_combined',
                                                  accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        run_log_dir = os.path.join(log_path, args.dset, exp_id)
        ckpt_prefix = os.path.join(ckpt_path, args.dset, exp_id,
                                   'test/model.ckpt')

        summary_writer = tf.summary.FileWriter(run_log_dir, sess.graph)
        # Both the summary writer and the text log are closed on every exit
        # path, including exceptions raised while training or evaluating.
        with closing(summary_writer), \
                open(os.path.join(run_log_dir, 'log.txt'), 'a') as f_log:
            utils.double_log(
                f_log, '\n###############################################\n' +
                exp_id + '\n#####################################\n')
            f_log.write(' '.join(sys.argv[:]) + '\n')
            f_log.flush()

            sess.run(tf.global_variables_initializer())
            if args.ckpt == '':
                # NOTE(review): s1_rgb_ckpt and s1_depth_ckpt are not defined
                # inside this function; they must exist at module level or
                # this line raises NameError -- confirm.
                restorers.restore_weights_s2(sess, s1_rgb_ckpt, s1_depth_ckpt)
            else:
                restorers.restore_weights_s2_continue(sess, args.ckpt)

            def val_test(value_step, mode='val'):
                """Evaluate one split and log RGB/depth/combined accuracy.

                Args:
                    value_step: Global step under which the accuracy
                        summaries are recorded.
                    mode: ``'val'`` or ``'test'``.

                Returns:
                    The combined accuracy on the split, or -1 when
                    ``mode == 'val'`` and there is no validation split.

                Raises:
                    ValueError: If ``mode`` is neither ``'val'`` nor
                        ``'test'``.
                """
                if mode == 'val':
                    if not val_filenames:
                        return -1
                    utils.double_log(f_log, "eval validation set \n")
                    sess.run(val_iterator.initializer)
                    step_handle = val_handle
                    step_samples = len(val_filenames)
                    step_summ_rgb = summary_acc_val_rgb
                    step_summ_depth = summary_acc_val_depth
                    step_summ_combined = summary_acc_val_combined
                elif mode == 'test':
                    utils.double_log(f_log, "eval test set \n")
                    sess.run(test_iterator.initializer)
                    step_handle = test_handle
                    step_samples = len(test_filenames)
                    step_summ_rgb = summary_acc_test_rgb
                    step_summ_depth = summary_acc_test_depth
                    step_summ_combined = summary_acc_test_combined
                else:
                    raise ValueError(
                        "mode must be 'val' or 'test', got %r" % (mode,))

                accum_correct_rgb = 0
                accum_correct_depth = 0
                accum_correct_combined = 0
                try:
                    while True:
                        batch_rgb, batch_depth, batch_combined = sess.run(
                            [n_correct_rgb, n_correct_depth,
                             n_correct_combined],
                            feed_dict={
                                handle: step_handle,
                                is_training: False
                            })
                        accum_correct_rgb += batch_rgb
                        accum_correct_depth += batch_depth
                        accum_correct_combined += batch_combined
                except tf.errors.OutOfRangeError:
                    pass  # iterator exhausted: the whole split was seen

                # NOTE(review): accuracy is normalised by the number of
                # files, which assumes one sample per TFRecord file --
                # TODO confirm.
                acc_rgb = accum_correct_rgb / step_samples
                acc_depth = accum_correct_depth / step_samples
                acc_combined = accum_correct_combined / step_samples

                sum_rgb_acc = sess.run(step_summ_rgb,
                                       feed_dict={accuracy_value_: acc_rgb})
                summary_writer.add_summary(sum_rgb_acc, value_step)
                sum_depth_acc = sess.run(
                    step_summ_depth, feed_dict={accuracy_value_: acc_depth})
                summary_writer.add_summary(sum_depth_acc, value_step)
                sum_combined_acc = sess.run(
                    step_summ_combined,
                    feed_dict={accuracy_value_: acc_combined})
                summary_writer.add_summary(sum_combined_acc, value_step)

                utils.double_log(f_log,
                                 'Depth accuracy = %s \n' % str(acc_depth))
                utils.double_log(f_log, 'RGB accuracy = %s \n' % str(acc_rgb))
                utils.double_log(
                    f_log, 'combined accuracy = %s \n' % str(acc_combined))
                return acc_combined

            if args.just_eval:
                val_test(-1, mode='test')
                return

            # Baseline accuracies before any training step.
            val_test(-1, mode='val')
            val_test(-1, mode='test')
            n_step = 0
            best_acc = best_epoch = best_step = -1
            for epoch in range(args.n_epochs):
                utils.double_log(f_log, 'epoch %s \n' % str(epoch))
                sess.run(train_iterator.initializer, feed_dict={seed: epoch})
                try:
                    while True:
                        print(n_step)
                        if n_step % 100 == 0:  # get summaries
                            _, summary = sess.run([minimizing, summary_train],
                                                  feed_dict={
                                                      handle: train_handle,
                                                      is_training: True
                                                  })
                            summary_writer.add_summary(summary, n_step)
                        else:
                            sess.run([minimizing],
                                     feed_dict={
                                         handle: train_handle,
                                         is_training: True
                                     })
                        n_step += 1
                except tf.errors.OutOfRangeError:
                    pass  # end of the epoch
                acc_validation = val_test(n_step, mode='val')

                # Without a validation split there is no model selection.
                if not val_filenames:
                    continue
                if acc_validation >= best_acc:
                    best_acc = acc_validation
                    best_epoch = epoch
                    best_step = n_step
                    test_saver.save(sess, ckpt_prefix, global_step=n_step)

            utils.double_log(f_log, "Optimization Finished!\n")
            if val_filenames:  # restore best validation model
                utils.double_log(
                    f_log,
                    "Best Validation Accuracy: %f at epoch %d \n" %
                    (best_acc, best_epoch))
                variables_to_restore = slim.get_variables_to_restore()
                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, ckpt_prefix + '-' + str(best_step))
            else:
                test_saver.save(sess, ckpt_prefix, global_step=n_step)

            val_test(n_step + 1, mode='test')
示例#3
0
def train(exp_id, files, args):
    hallucination_layer_true = 'resnet_v1_50_of/block4'
    hallucination_layer_hall = 'resnet_v1_50_hall/block4'

    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=True))
        seed = tf.placeholder(tf.int64, shape=())
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(
                lambda x: parsers._parse_fun_2stream(x, is_training=False))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=False))
        dset_test = dset_test.batch(args.batch_sz)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_depth_stacked = next_element[0]  # [batch, pooled_frames, h,w,c]
        images_rgb_stacked = next_element[1]
        if args.dset == 'uwa3dii':  # because tfrecords labels are [1,30]
            labels = next_element[2] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[2]
            labels_per_frame = next_element[3]
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)
        labels_per_frame = tf.reshape(labels_per_frame, [-1])
        labels_per_frame = tf.one_hot(labels_per_frame, n_classes)

        rgb_stack_shape = tf.shape(images_rgb_stacked)
        depth_stack_shape = tf.shape(images_depth_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        images_rgb = tf.reshape(
            images_rgb_stacked,
            [rgb_stack_shape[0] * rgb_stack_shape[1], 224, 224, 3])
        images_depth = tf.reshape(
            images_depth_stacked,
            [depth_stack_shape[0] * depth_stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            is_training = tf.placeholder(tf.bool, [])
            nr_frames = parsers.time_bottleneck

            net_depth_out, endpoints_depth = resnet_v1.resnet_one_stream_main(
                images_depth,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_depth',
                gpu_id='/gpu:0',
                is_training=False,
                bottleneck=True)

            net_hall_out, net_hall_endpoints = resnet_v1.resnet_one_stream_main(
                images_rgb,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_hall',
                gpu_id='/gpu:1',
                is_training=is_training,
                bottleneck=True)

            nr_batch_frames = tf.shape(labels_per_frame)
            nr_batch_vid = tf.shape(labels)

            temporal_order = utils.get_temporal_order_onehot(
                nr_batch_vid[0], parsers.time_bottleneck)
            temporal_order = tf.expand_dims(temporal_order, axis=1)
            temporal_order = tf.expand_dims(temporal_order, axis=1)

            feat_depth = endpoints_depth['last_pool']
            logits_real = resnet_v1.feature_discriminator(feat_depth,
                                                          temporal_order,
                                                          n_classes=n_classes)

            feat_hall = net_hall_endpoints['last_pool']
            logits_fake = resnet_v1.feature_discriminator(feat_hall,
                                                          temporal_order,
                                                          n_classes=n_classes,
                                                          reuse=True)

            logits_real = tf.squeeze(logits_real, [1, 2])
            logits_fake = tf.squeeze(logits_fake, [1, 2])

            # TRAIN ###############################
            net_depth_train = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_train = tf.reduce_mean(net_depth_train, axis=1)
            net_hall_train = tf.reshape(
                net_hall_out, [-1, parsers.time_bottleneck, n_classes])
            net_hall_train = tf.reduce_mean(net_hall_train, axis=1)

            # TEST ###############################
            net_hall_test = tf.reshape(net_hall_out,
                                       [-1, utils.time_bottleneck, n_classes])
            net_hall_test = tf.reduce_mean(net_hall_test, axis=1)
            net_depth_test = tf.reshape(net_depth_out,
                                        [-1, utils.time_bottleneck, n_classes])
            net_depth_test = tf.reduce_mean(net_depth_test, axis=1)

            # losses ##########################################################
            d_target_dist_real = tf.concat(
                axis=-1,
                values=[
                    tf.zeros([nr_batch_frames[0], 1], tf.float32),
                    tf.cast(labels_per_frame, tf.float32)
                ])
            d_loss_real = slim.losses.softmax_cross_entropy(
                logits_real, d_target_dist_real)
            d_target_dist_fake = tf.concat(
                axis=-1,
                values=[
                    tf.ones([nr_batch_frames[0], 1], tf.float32),
                    tf.zeros([nr_batch_frames[0], n_classes], tf.float32)
                ])
            d_loss_fake = slim.losses.softmax_cross_entropy(
                logits_fake, d_target_dist_fake)
            d_loss = .5 * (d_loss_real + d_loss_fake)

            g_target_dist_fake = tf.concat(
                axis=-1,
                values=[
                    tf.zeros([nr_batch_frames[0], 1], tf.float32),
                    tf.cast(labels_per_frame, tf.float32)
                ])
            g_loss = slim.losses.softmax_cross_entropy(logits_fake,
                                                       g_target_dist_fake)

            loss_hall_rect_static = utils.loss_hall_rect(
                endpoints_depth[hallucination_layer_true],
                net_hall_endpoints[hallucination_layer_hall])
            loss_hall_rect_static2 = utils.loss_hall_rect(
                endpoints_depth['last_pool'], net_hall_endpoints['last_pool'])

            d_optimizer = tf.train.AdamOptimizer(args.learning_rate)
            g_optimizer = tf.train.AdamOptimizer(args.learning_rate)

            t_vars = tf.trainable_variables()
            # freezing depth
            depth_vars = [x for x in t_vars if 'resnet_v1_50_of' in x.name]
            to_remove = depth_vars
            train_vars = [x for x in t_vars if x not in to_remove]

            train_vars_d = [x for x in train_vars if 'disc_e' in x.name]
            minimizing_d = slim.learning.create_train_op(
                d_loss, d_optimizer, variables_to_train=train_vars_d)

            train_vars_g = [
                x for x in train_vars if 'resnet_v1_50_hall' in x.name
            ]
            minimizing_g = slim.learning.create_train_op(
                g_loss, g_optimizer, variables_to_train=train_vars_g)

            ###################################################################
            acc_depth_train = utils.accuracy(net_depth_train, labels)
            acc_hall_train = utils.accuracy(net_hall_train, labels)

            n_correct_depth = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_depth_test, labels),
                        tf.float32))
            n_correct_hall = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_hall_test, labels), tf.float32))
            ###################################################################

    summ_d_loss = tf.summary.scalar('d_loss', d_loss)
    summ_d_loss_real = tf.summary.scalar('d_loss_real', d_loss_real)
    summ_d_loss_fake = tf.summary.scalar('d_loss_fake', d_loss_fake)
    summ_g_loss = tf.summary.scalar('g_loss', g_loss)
    summ_loss_hall_rect_static = tf.summary.scalar('loss_euclid_hall_layer',
                                                   loss_hall_rect_static)
    summ_loss_hall_rect_static2 = tf.summary.scalar('loss_euclid_hall_2',
                                                    loss_hall_rect_static2)
    summ_acc_train_hall = tf.summary.scalar('acc_train_hall', acc_hall_train)
    summ_acc_train_depth = tf.summary.scalar('acc_train_depth',
                                             acc_depth_train)
    summary_train = tf.summary.merge([
        summ_d_loss, summ_d_loss_real, summ_d_loss_fake, summ_g_loss,
        summ_loss_hall_rect_static, summ_loss_hall_rect_static2,
        summ_acc_train_hall, summ_acc_train_depth
    ])

    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summ_acc_val_hall = tf.summary.scalar('acc_val_hall', accuracy_value_)
    summ_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    summ_acc_test_hall = tf.summary.scalar('acc_test_hall', accuracy_value_)
    summ_acc_test_depth = tf.summary.scalar('acc_test_depth', accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, args.dset, exp_id), sess.graph)

        f_log = open(os.path.join(log_path, args.dset, exp_id, 'log.txt'), 'a')
        utils.double_log(
            f_log, '\n###############################################\n' +
            exp_id + '\n#####################################\n')
        f_log.write(' '.join(sys.argv[:]) + '\n')
        f_log.flush()

        sess.run(tf.global_variables_initializer())
        if args.ckpt == '':
            sys.exit('Please specify the depth checkpoint')
        restorers.restore_weights_s2_5_gan_depth(sess, args.ckpt)

        def val_test(value_step, mode='val'):
            if mode == 'val':
                if not val_filenames:
                    return -1
                utils.double_log(f_log, "eval validation set \n")
                sess.run(val_iterator.initializer)
                step_handle = val_handle
                step_samples = len(val_filenames)
                step_summ_hall = summ_acc_val_hall
                step_summ_depth = summ_acc_val_depth
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_iterator.initializer)
                step_handle = test_handle
                step_samples = len(test_filenames)
                step_summ_hall = summ_acc_test_hall
                step_summ_depth = summ_acc_test_depth
            try:
                accum_correct_depth = accum_correct_hall = 0
                while True:
                    n_correct_hall1, n_correct_depth1 = sess.run(
                        [n_correct_hall, n_correct_depth],
                        feed_dict={
                            handle: step_handle,
                            is_training: False
                        })
                    accum_correct_depth += n_correct_depth1
                    accum_correct_hall += n_correct_hall1
            except tf.errors.OutOfRangeError:
                acc_hall = accum_correct_hall / step_samples
                acc_depth = accum_correct_depth / step_samples
                summ_hall_acc = sess.run(step_summ_hall,
                                         feed_dict={accuracy_value_: acc_hall})
                summary_writer.add_summary(summ_hall_acc, value_step)
                summ_depth_acc = sess.run(
                    step_summ_depth, feed_dict={accuracy_value_: acc_depth})
                summary_writer.add_summary(summ_depth_acc, value_step)
                utils.double_log(f_log, 'Hall acc = %s \n' % str(acc_hall))
                utils.double_log(f_log, 'Depth acc = %s \n' % str(acc_depth))
                return acc_hall

        if args.just_eval:
            val_test(-1, mode='test')
            f_log.close()
            summary_writer.close()
            return

        val_test(-1, mode='val')
        val_test(-1, mode='test')
        n_step = 0
        best_acc = best_epoch = best_step = -1
        for epoch in range(args.n_epochs):
            utils.double_log(f_log, 'epoch %s \n' % str(epoch))
            sess.run(train_iterator.initializer, feed_dict={seed: epoch})
            try:
                while True:
                    print(n_step)
                    if n_step % 100 == 0:
                        _, _, summ_train = sess.run(
                            [minimizing_d, minimizing_g, summary_train],
                            feed_dict={
                                handle: train_handle,
                                nr_frames: parsers.time_bottleneck,
                                is_training: True
                            })
                        summary_writer.add_summary(summ_train, n_step)
                    else:
                        sess.run(
                            [minimizing_d, minimizing_g],
                            feed_dict={
                                handle: train_handle,
                                nr_frames: parsers.time_bottleneck,
                                is_training: True
                            })
                    n_step += 1
            except tf.errors.OutOfRangeError:
                acc_validation = val_test(n_step, mode='val')

            if val_filenames:
                acc_epoch = acc_validation
            else:
                continue
            if acc_epoch >= best_acc:
                best_acc = acc_epoch
                best_epoch = epoch
                best_step = n_step
                test_saver.save(sess,
                                os.path.join(ckpt_path, args.dset, exp_id,
                                             'test/model.ckpt'),
                                global_step=n_step)

        utils.double_log(f_log, "Optimization Finished!\n")
        if val_filenames:  # restore best validation model
            utils.double_log(
                f_log,
                str("Best Validation Accuracy: %f at epoch %d %d\n" %
                    (best_acc, best_epoch, best_step)))
            variables_to_restore = slim.get_variables_to_restore()
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(
                sess,
                os.path.join(ckpt_path, args.dset, exp_id,
                             'test/model.ckpt-' + str(best_step)))
        else:
            test_saver.save(sess,
                            os.path.join(ckpt_path, args.dset, exp_id,
                                         'test/model.ckpt'),
                            global_step=n_step)

        val_test(n_step + 1, mode='test')
        f_log.close()
        summary_writer.close()
示例#4
0
def train(exp_id, files, args):
    """Train the hallucination stream against a frozen depth stream.

    Builds train/val/test input pipelines from GZIP-compressed TFRecords,
    instantiates the two-stream network through
    ``resnet_v1.resnet_twostream_main`` and optimises, with Adam, a
    combination of:

    * a hard-label softmax cross entropy on the hallucination logits,
    * a soft-label cross entropy between the ``utils.softmax_distill``
      outputs (T=10) of the depth logits and the hallucination logits,
    * ``utils.loss_hall_rect`` between the ``block4`` endpoints of the two
      streams.

    Trainable variables whose name contains ``'resnet_v1_50_depth/'`` are
    excluded from optimisation (the depth stream is frozen).

    After every epoch the validation split (if any) is evaluated and a
    checkpoint is saved whenever the validation accuracy is at least as good
    as the best one seen so far. At the end, the best checkpoint is restored
    (or, without a validation split, the final weights are saved) and the
    test split is evaluated.

    Args:
        exp_id: Experiment identifier; used as a sub-directory name under
            ``./log/<dset>/`` and ``./checkpoint/<dset>/``.
        files: Mapping; ``files['data']`` is forwarded to
            ``utils.get_tfrecords``.
        args: Options object. Attributes read here: ``eval_mode``, ``dset``,
            ``batch_sz``, ``gpu0``, ``interaction``, ``learning_rate``,
            ``ckpt``, ``just_eval`` and ``n_epochs``.

    Raises:
        ValueError: If ``args.dset`` is not one of the datasets for which the
            label encoding is known (``uwa3dii``, ``nwucla``, ``*ntu*``).
    """
    hallucination_layer_true = 'resnet_v1_50_depth/block4'
    hallucination_layer_hall = 'resnet_v1_50_hall/block4'
    scope_rgb = 'hall'

    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=True))
        # The shuffle seed is fed with the epoch index when the training
        # iterator is (re-)initialised.
        seed = tf.placeholder(tf.int64, shape=())
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(
                lambda x: parsers._parse_fun_2stream(x, is_training=False))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=False))
        dset_test = dset_test.batch(args.batch_sz)

        # Feedable iterator: the split to read from is selected at run time
        # by feeding the corresponding string handle.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_depth_stacked = next_element[0]  # [batch, pooled_frames, h,w,c]
        images_rgb_stacked = next_element[1]
        if args.dset == 'uwa3dii':  # because tfrecords labels are [1,30]
            labels = next_element[2] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[2]
        else:
            # Fail with a clear message instead of a NameError further down.
            raise ValueError('Unsupported dataset: %s' % args.dset)
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)

        rgb_stack_shape = tf.shape(images_rgb_stacked)
        depth_stack_shape = tf.shape(images_depth_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        images_rgb = tf.reshape(
            images_rgb_stacked,
            [rgb_stack_shape[0] * rgb_stack_shape[1], 224, 224, 3])
        images_depth = tf.reshape(
            images_depth_stacked,
            [depth_stack_shape[0] * depth_stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            is_training = tf.placeholder(tf.bool, [])
            nr_frames = parsers.time_bottleneck

            (net_depth_out, endpoints_depth, net_hall_out,
             endpoints_hall) = resnet_v1.resnet_twostream_main(
                 images_depth,
                 images_rgb,
                 nr_frames,
                 num_classes=n_classes,
                 scope_rgb_stream=scope_rgb,
                 depth_training=False,
                 is_training=is_training,
                 interaction=args.interaction)

            # TRAIN ###############################
            # Average the per-frame logits over the frames of each sample.
            net_depth_train = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_train = tf.reduce_mean(net_depth_train, axis=1)
            net_hall_train = tf.reshape(
                net_hall_out, [-1, parsers.time_bottleneck, n_classes])
            net_hall_train = tf.reduce_mean(net_hall_train, axis=1)

            # TEST ###############################
            net_hall_test = tf.reshape(
                net_hall_out, [-1, parsers.time_bottleneck, n_classes])
            net_hall_test = tf.reduce_mean(net_hall_test, axis=1)
            net_depth_test = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_test = tf.reduce_mean(net_depth_test, axis=1)

            # losses ##########################################################
            loss_hard_hall_class = slim.losses.softmax_cross_entropy(
                net_hall_train, labels)  # hard labels loss

            T = 10  # Temperature T
            # Depth stream acts as teacher, hallucination stream as student.
            teacher_soft_labels = utils.softmax_distill(net_depth_train, T)
            student_soft_labels = utils.softmax_distill(net_hall_train, T)
            loss_soft_hall_class = utils.cross_entropy(student_soft_labels,
                                                       teacher_soft_labels)

            lambda1 = 0.5  # lambda imitation - contrib of soft_label_entropy
            loss_hall_distill = (1 - lambda1) * loss_hard_hall_class + \
                lambda1 * loss_soft_hall_class

            loss_hall_rect_static = utils.loss_hall_rect(
                endpoints_depth[hallucination_layer_true],
                endpoints_hall[hallucination_layer_hall])

            lambda2 = 0.5  # strength of cross entropy loss
            lambda3 = 0.01  # extra scaling applied to the block4 feature loss
            loss = lambda2 * loss_hall_distill + \
                (1 - lambda2) * loss_hall_rect_static * lambda3

            optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            # freezing depth: train everything except the depth stream
            train_vars = [
                x for x in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                if 'resnet_v1_50_depth/' not in x.name
            ]
            minimizing = slim.learning.create_train_op(
                loss, optimizer, variables_to_train=train_vars)

            ###################################################################
            acc_depth_train = utils.accuracy(net_depth_train, labels)
            acc_hall_train = utils.accuracy(net_hall_train, labels)

            n_correct_depth = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_depth_test, labels),
                        tf.float32))
            n_correct_hall = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_hall_test, labels), tf.float32))

    summ_loss_combined = tf.summary.scalar('loss', loss)
    summ_loss_hall_l2 = tf.summary.scalar('loss_hall_l2',
                                          loss_hall_rect_static)
    summ_loss_hard_hall = tf.summary.scalar('loss_hard_hall_class',
                                            loss_hard_hall_class)
    summ_loss_soft_hall = tf.summary.scalar('loss_soft_hall_class',
                                            loss_soft_hall_class)
    summ_acc_train_hall = tf.summary.scalar('acc_train_hall', acc_hall_train)
    summ_acc_train_depth = tf.summary.scalar('acc_train_depth',
                                             acc_depth_train)
    summ_train = tf.summary.merge([
        summ_acc_train_hall, summ_acc_train_depth, summ_loss_hall_l2,
        summ_loss_hard_hall, summ_loss_combined, summ_loss_soft_hall
    ])

    # Evaluation accuracies are computed in Python and pushed into the
    # summaries through this placeholder.
    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summ_acc_val_hall = tf.summary.scalar('acc_val_hall', accuracy_value_)
    summ_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    summ_acc_test_hall = tf.summary.scalar('acc_test_hall', accuracy_value_)
    summ_acc_test_depth = tf.summary.scalar('acc_test_depth', accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    run_log_dir = os.path.join(log_path, args.dset, exp_id)
    ckpt_prefix = os.path.join(ckpt_path, args.dset, exp_id,
                               'test/model.ckpt')

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        summary_writer = tf.summary.FileWriter(run_log_dir, sess.graph)
        f_log = None
        try:
            f_log = open(os.path.join(run_log_dir, 'log.txt'), 'a')
            utils.double_log(
                f_log, '\n###############################################\n' +
                exp_id + '\n#####################################\n')
            f_log.write(' '.join(sys.argv[:]) + '\n')
            utils.update_details_file(f_log,
                                      lambda1=lambda1,
                                      lambda2=lambda2,
                                      lambda3=lambda3)

            sess.run(tf.global_variables_initializer())
            restorers.restore_weights_s3(sess, args.ckpt)

            def val_test(value_step, mode='val'):
                """Evaluate one split; log and return hallucination accuracy.

                Returns -1 when ``mode == 'val'`` and there is no validation
                split. ``value_step`` is the global step attached to the
                written summaries.
                """
                if mode == 'val':
                    if not val_filenames:
                        return -1
                    utils.double_log(f_log, "eval validation set \n")
                    sess.run(val_iterator.initializer)
                    step_handle = val_handle
                    step_samples = len(val_filenames)
                    step_summ_hall = summ_acc_val_hall
                    step_summ_depth = summ_acc_val_depth
                elif mode == 'test':
                    utils.double_log(f_log, "eval test set \n")
                    sess.run(test_iterator.initializer)
                    step_handle = test_handle
                    step_samples = len(test_filenames)
                    step_summ_hall = summ_acc_test_hall
                    step_summ_depth = summ_acc_test_depth
                else:
                    raise ValueError('Unknown evaluation mode: %s' % mode)

                accum_correct_depth = accum_correct_hall = 0
                try:
                    while True:
                        n_correct_hall1, n_correct_depth1 = sess.run(
                            [n_correct_hall, n_correct_depth],
                            feed_dict={
                                handle: step_handle,
                                is_training: False
                            })
                        accum_correct_depth += n_correct_depth1
                        accum_correct_hall += n_correct_hall1
                except tf.errors.OutOfRangeError:
                    pass  # iterator exhausted: the whole split was consumed

                # NOTE(review): the denominator is the number of tfrecord
                # files, i.e. this assumes one sample per file - confirm.
                acc_hall = accum_correct_hall / step_samples
                acc_depth = accum_correct_depth / step_samples
                summ_hall_acc = sess.run(
                    step_summ_hall, feed_dict={accuracy_value_: acc_hall})
                summary_writer.add_summary(summ_hall_acc, value_step)
                summ_depth_acc = sess.run(
                    step_summ_depth, feed_dict={accuracy_value_: acc_depth})
                summary_writer.add_summary(summ_depth_acc, value_step)
                utils.double_log(f_log, 'Hall acc = %s \n' % str(acc_hall))
                utils.double_log(f_log, 'Depth acc = %s \n' % str(acc_depth))
                return acc_hall

            if args.just_eval:
                val_test(-1, mode='test')
                return

            # Baseline accuracies of the restored weights before training.
            val_test(-1, mode='val')
            val_test(-1, mode='test')
            n_step = 0
            best_acc = best_epoch = best_step = -1
            for epoch in range(args.n_epochs):
                utils.double_log(f_log, 'epoch %s \n' % str(epoch))
                sess.run(train_iterator.initializer, feed_dict={seed: epoch})
                try:
                    while True:
                        print(n_step)
                        if n_step % 100 == 0:
                            # Every 100th step also records train summaries.
                            _, summary = sess.run([minimizing, summ_train],
                                                  feed_dict={
                                                      handle: train_handle,
                                                      is_training: True
                                                  })
                            summary_writer.add_summary(summary, n_step)
                        else:
                            sess.run([minimizing],
                                     feed_dict={
                                         handle: train_handle,
                                         is_training: True
                                     })
                        n_step += 1
                except tf.errors.OutOfRangeError:
                    # End of epoch: the training iterator is exhausted.
                    acc_validation = val_test(n_step, mode='val')

                if not val_filenames:
                    continue
                if acc_validation >= best_acc:
                    best_acc = acc_validation
                    best_epoch = epoch
                    best_step = n_step
                    test_saver.save(sess, ckpt_prefix, global_step=n_step)

            utils.double_log(f_log, "Optimization Finished!\n")
            if val_filenames:  # restore best validation model
                utils.double_log(
                    f_log,
                    str("Best Validation Accuracy: %f at epoch %d %d\n" %
                        (best_acc, best_epoch, best_step)))
                # best_step stays -1 when no epoch was run; in that case no
                # checkpoint exists and the current weights are evaluated.
                if best_step >= 0:
                    variables_to_restore = slim.get_variables_to_restore()
                    restorer = tf.train.Saver(variables_to_restore)
                    restorer.restore(sess,
                                     ckpt_prefix + '-' + str(best_step))
            else:
                test_saver.save(sess, ckpt_prefix, global_step=n_step)

            val_test(n_step + 1, mode='test')
        finally:
            # Release the log file and flush pending summaries on every exit
            # path, including exceptions raised during training/evaluation.
            if f_log is not None:
                f_log.close()
            summary_writer.close()