Example #1
def val_test(value_step, mode='val'):
     if mode == 'val':
         if not val_filenames:
             return -1
         utils.double_log(f_log, "eval validation set \n")
         sess.run(val_iterator.initializer)
         step_handle = val_handle
         step_samples = len(val_filenames)
         step_summ_rgb = summary_acc_val_rgb
         step_summ_depth = summary_acc_val_depth
         step_summ_combined = summary_acc_val_combined
     elif mode == 'test':
         utils.double_log(f_log, "eval test set \n")
         sess.run(test_iterator.initializer)
         step_handle = test_handle
         step_samples = len(test_filenames)
         step_summ_rgb = summary_acc_test_rgb
         step_summ_depth = summary_acc_test_depth
         step_summ_combined = summary_acc_test_combined
     try:
         accum_correct_rgb = accum_correct_depth = accum_correct_combined_val = 0
         while True:
             n_correct_rgb1, n_correct_depth1, n_correct_combined1 = sess.run(
                 [n_correct_rgb, n_correct_depth, n_correct_combined],
                 feed_dict={
                     handle: step_handle,
                     is_training: False
                 })
             accum_correct_rgb += n_correct_rgb1
             accum_correct_depth += n_correct_depth1
             accum_correct_combined_val += n_correct_combined1
     except tf.errors.OutOfRangeError:
         acc_rgb = accum_correct_rgb / step_samples
         acc_depth = accum_correct_depth / step_samples
         acc_combined = accum_correct_combined_val / step_samples
         sum_rgb_acc = sess.run(step_summ_rgb,
                                feed_dict={accuracy_value_: acc_rgb})
         summary_writer.add_summary(sum_rgb_acc, value_step)
         sum_depth_acc = sess.run(
             step_summ_depth, feed_dict={accuracy_value_: acc_depth})
         summary_writer.add_summary(sum_depth_acc, value_step)
         sum_combined_acc = sess.run(
             step_summ_combined,
             feed_dict={accuracy_value_: acc_combined})
         summary_writer.add_summary(sum_combined_acc, value_step)
         utils.double_log(f_log,
                          'Depth accuracy = %s \n' % str(acc_depth))
         utils.double_log(f_log, 'RGB accuracy = %s \n' % str(acc_rgb))
         utils.double_log(
             f_log, 'combined accuracy = %s \n' % str(acc_combined))
         return acc_combined
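
A note on the pattern shared by these val_test variants: they use TF 1.x feedable iterators, where a string-handle placeholder selects which dataset feeds the shared get_next() op, and evaluation drains the chosen iterator until tf.errors.OutOfRangeError while accumulating correct counts in Python. Below is a minimal, self-contained sketch of that pattern with toy data and made-up names (not the repository's own graph):

import tensorflow as tf  # TF 1.x

# Toy dataset of 10 labels, all zero, so a constant "predict class 0" model is always right.
data = tf.data.Dataset.from_tensor_slices(tf.zeros([10], tf.int32)).batch(4)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, data.output_types,
                                               data.output_shapes)
labels = iterator.get_next()
n_correct = tf.reduce_sum(tf.cast(tf.equal(labels, 0), tf.float32))

eval_iterator = data.make_initializable_iterator()
with tf.Session() as sess:
    eval_handle = sess.run(eval_iterator.string_handle())
    sess.run(eval_iterator.initializer)
    accum_correct = 0.0
    try:
        while True:  # loop until the iterator is exhausted
            accum_correct += sess.run(n_correct, feed_dict={handle: eval_handle})
    except tf.errors.OutOfRangeError:
        print('accuracy =', accum_correct / 10)
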
Example #2
        def val_test(value_step, mode='val'):
            if mode == 'val':
                if not val_filenames:
                    return -1
                utils.double_log(f_log, "eval validation set \n")
                sess.run(val_iterator.initializer)
                step_handle = val_handle
                step_samples = len(val_filenames)
                step_summ = summ_acc_val
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_iterator.initializer)
                step_handle = test_handle
                step_samples = len(test_filenames)
                step_summ = summ_acc_test

            try:
                accum_correct = 0
                while True:
                    n_correct_val = sess.run(n_correct,
                                             feed_dict={
                                                 handle: step_handle,
                                                 is_training: False
                                             })
                    accum_correct += n_correct_val
            except tf.errors.OutOfRangeError:
                step_acc = accum_correct / step_samples
                summary_acc = sess.run(step_summ,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log, 'Accuracy = %s \n' % str(step_acc))
            return step_acc
Example #3
def val_test(value_step, mode='val'):
     if mode == 'val':
         if not val_filenames:
             return -1
         utils.double_log(f_log, "eval validation set \n")
         sess.run(val_iterator.initializer)
         step_handle = val_handle
         step_samples = len(val_filenames)
         step_summ_hall = summ_acc_val_hall
         step_summ_depth = summ_acc_val_depth
     elif mode == 'test':
         utils.double_log(f_log, "eval test set \n")
         sess.run(test_iterator.initializer)
         step_handle = test_handle
         step_samples = len(test_filenames)
         step_summ_hall = summ_acc_test_hall
         step_summ_depth = summ_acc_test_depth
     try:
         accum_correct_depth = accum_correct_hall = 0
         while True:
             n_correct_hall1, n_correct_depth1 = sess.run(
                 [n_correct_hall, n_correct_depth],
                 feed_dict={
                     handle: step_handle,
                     is_training: False
                 })
             accum_correct_depth += n_correct_depth1
             accum_correct_hall += n_correct_hall1
     except tf.errors.OutOfRangeError:
         acc_hall = accum_correct_hall / step_samples
         acc_depth = accum_correct_depth / step_samples
         summ_hall_acc = sess.run(step_summ_hall,
                                  feed_dict={accuracy_value_: acc_hall})
         summary_writer.add_summary(summ_hall_acc, value_step)
         summ_depth_acc = sess.run(
             step_summ_depth, feed_dict={accuracy_value_: acc_depth})
         summary_writer.add_summary(summ_depth_acc, value_step)
         utils.double_log(f_log, 'Hall acc = %s \n' % str(acc_hall))
         utils.double_log(f_log, 'Depth acc = %s \n' % str(acc_depth))
         return acc_hall
Example #4
def train(exp_id, files, args):
    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(lambda x: parsers._parse_fun_one_mod(
            x, is_training=True, modality=args.modality))
        seed = tf.placeholder(tf.int64, shape=())  # =epoch
        dset_train = dset_train.shuffle(100, seed=seed)
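        # Feeding the epoch number as the shuffle seed (see the initializer feed
        # further below) gives a different but reproducible shuffle each epoch.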
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(lambda x: parsers._parse_fun_one_mod(
                x, is_training=False, modality=args.modality))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(lambda x: parsers._parse_fun_one_mod(
            x, is_training=False, modality=args.modality))
        dset_test = dset_test.batch(args.batch_sz)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_stacked = next_element[0]  # [batch_sz, time_bottleneck, h,w,c]
        if args.dset == 'uwa3dii':  # because tfrecords labels are [1,30]
            labels = next_element[1] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[1]
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)

        stack_shape = tf.shape(images_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        batch_images = tf.reshape(
            images_stacked, [stack_shape[0] * stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(
                resnet_v1.resnet_arg_scope(batch_norm_decay=args.bn_decay)):
            is_training = tf.placeholder(tf.bool, name="is_training")
            nr_frames = parsers.time_bottleneck
            scope = 'resnet_v1_50'

            net_out, net_endpoints = resnet_v1.resnet_one_stream_main(
                batch_images,
                nr_frames,
                num_classes=n_classes,
                scope=scope,
                gpu_id=args.gpu0,
                is_training=is_training)

            # predictions for each video are the avg of frames' predictions
            # TRAIN ###############################
            net_train = tf.reshape(net_out, [-1, nr_frames, n_classes])
            net_train = tf.reduce_mean(net_train, axis=1)
            # TEST ###############################
            net_test = tf.reshape(net_out, [-1, nr_frames, n_classes])
            net_test = tf.reduce_mean(net_test, axis=1)

            # loss ##########################################################
            loss = slim.losses.softmax_cross_entropy(net_train, labels)
            # optimizers ######################################################
            optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            minimizing = slim.learning.create_train_op(loss, optimizer)

            acc_train = utils.accuracy(net_train, labels)
            n_correct = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_test, labels), tf.float32))

    summ_loss = tf.summary.scalar('loss', loss)
    summ_acc_train = tf.summary.scalar('acc_train', acc_train)
    summ_train = tf.summary.merge([summ_acc_train, summ_loss])
    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summ_acc_test = tf.summary.scalar('acc_test', accuracy_value_)
    summ_acc_val = tf.summary.scalar('acc_val', accuracy_value_)
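    # Note: accuracy_value_ is a plain scalar placeholder; val/test accuracy is computed
    # in Python inside val_test() and fed back in only to emit a TensorBoard scalar.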
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, args.dset, exp_id), sess.graph)

        f_log = open(os.path.join(log_path, args.dset, exp_id, 'log.txt'), 'a')
        utils.double_log(
            f_log, '\n###############################################\n' +
            exp_id + '\n#####################################\n')
        f_log.write(' '.join(sys.argv[:]) + '\n')
        f_log.flush()

        sess.run(tf.global_variables_initializer())
        if args.just_eval:
            restorers.restore_weights_s1_continue(sess, args.ckpt,
                                                  args.modality)
        else:
            restorers.restore_weights_s1(sess,
                                         files['imagenet_checkpoint_file'])

        def val_test(value_step, mode='val'):
            if mode == 'val':
                if not val_filenames:
                    return -1
                utils.double_log(f_log, "eval validation set \n")
                sess.run(val_iterator.initializer)
                step_handle = val_handle
                step_samples = len(val_filenames)
                step_summ = summ_acc_val
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_iterator.initializer)
                step_handle = test_handle
                step_samples = len(test_filenames)
                step_summ = summ_acc_test

            try:
                accum_correct = 0
                while True:
                    n_correct_val = sess.run(n_correct,
                                             feed_dict={
                                                 handle: step_handle,
                                                 is_training: False
                                             })
                    accum_correct += n_correct_val
            except tf.errors.OutOfRangeError:
                step_acc = accum_correct / step_samples
                summary_acc = sess.run(step_summ,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log, 'Accuracy = %s \n' % str(step_acc))
            return step_acc

        if args.just_eval:
            val_test(-1, mode='test')
            f_log.close()
            summary_writer.close()
            return

        val_test(-1, mode='val')
        val_test(-1, mode='test')
        n_step = 0
        best_acc = best_epoch = best_step = -1
        for epoch in range(args.n_epochs):
            utils.double_log(f_log, 'epoch %s \n' % str(epoch))
            sess.run(train_iterator.initializer, feed_dict={seed: epoch})
            try:
                while True:
                    print(n_step)
                    if n_step % 100 == 0:  # get summaries
                        _, summary = sess.run([minimizing, summ_train],
                                              feed_dict={
                                                  handle: train_handle,
                                                  is_training: True
                                              })
                        summary_writer.add_summary(summary, n_step)
                    else:
                        sess.run(minimizing,
                                 feed_dict={
                                     handle: train_handle,
                                     is_training: True
                                 })
                    n_step += 1
            except tf.errors.OutOfRangeError:
                acc_validation = val_test(n_step, mode='val')

            if val_filenames:
                acc_epoch = acc_validation
            else:
                continue
            if acc_epoch >= best_acc:
                best_acc = acc_epoch
                best_epoch = epoch
                best_step = n_step
                test_saver.save(sess,
                                os.path.join(ckpt_path, args.dset, exp_id,
                                             'test/model.ckpt'),
                                global_step=n_step)

        utils.double_log(f_log, "Optimization Finished!\n")
        if val_filenames:
            utils.double_log(
                f_log,
                str("Best Validation Accuracy: %f at epoch %d, step %d\n" %
                    (best_acc, best_epoch, best_step)))
            variables_to_restore = slim.get_variables_to_restore()
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(
                sess,
                os.path.join(ckpt_path, args.dset, exp_id,
                             'test/model.ckpt-' + str(best_step)))
        else:
            test_saver.save(sess,
                            os.path.join(ckpt_path, args.dset, exp_id,
                                         'test/model.ckpt'),
                            global_step=n_step)

        val_test(n_step + 1, mode='test')
        f_log.close()
        summary_writer.close()
Example #5
        def val_test(value_step, mode='val'):
            if mode == 'val' and val_fnames:
                utils.double_log(f_log, "eval val set \n")
                sess.run(val_it.initializer)
                step_handle = val_handle
                step_samples = len(val_fnames) / 10
                step_summ_rgb = summ_acc_val_rgb
                step_summ_depth = summ_acc_val_depth
                step_summ_flow = summ_acc_val_flow
                step_summ_sum = summ_acc_val_sum
                step_summ_oracle = summ_acc_val_oracle
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_it.initializer)
                step_handle = test_handle
                step_samples = len(test_fnames) / 10
                step_summ_rgb = summ_acc_test_rgb
                step_summ_depth = summ_acc_test_depth
                step_summ_flow = summ_acc_test_flow
                step_summ_sum = summ_acc_test_sum
                step_summ_oracle = summ_acc_test_oracle
            else:
                return -1
            try:
                dict_sum = {}
                dict_rgb = {}
                dict_of = {}
                dict_depth = {}
                hist_acc_per_class = np.zeros([5, n_classes])
                n_samples_per_class = np.zeros(n_classes)
                while True:
                    logits_sum, logits_r, logits_d, logits_f, video_id_val = sess.run(
                        [
                            logits_fused, logits_rgb, logits_depth,
                            logits_oflow, batch_video_id
                        ],
                        feed_dict={
                            handle: step_handle,
                            is_training: False
                        })
                    for i in range(len(video_id_val)):
                        v = video_id_val[i]
                        if v in dict_sum:
                            dict_sum[v] = dict_sum[v] + logits_sum[i]
                        else:
                            dict_sum[v] = logits_sum[i]
                            if 'ntu' in args.dset:
                                this_class = int(v[-3:]) - 1
                            elif 'nwucla' in args.dset:
                                this_class = utils.get_nwucla_class(
                                    int(v[8:10])) - 1
                            elif 'uwa3dii' in args.dset:
                                this_class = int(v[1:3]) - 1
                            n_samples_per_class[this_class] += 1
                        if v in dict_rgb:
                            dict_rgb[v] = dict_rgb[v] + logits_r[i]
                        else:
                            dict_rgb[v] = logits_r[i]
                        if v in dict_of:
                            dict_of[v] = dict_of[v] + logits_f[i]
                        else:
                            dict_of[v] = logits_f[i]
                        if v in dict_depth:
                            dict_depth[v] = dict_depth[v] + logits_d[i]
                        else:
                            dict_depth[v] = logits_d[i]
            except tf.errors.OutOfRangeError:
                accum_correct_sum = 0
                accum_correct_r = 0
                accum_correct_d = 0
                accum_correct_f = 0
                accum_correct_oracle = 0
                for key in dict_sum:
                    dict_sum[key] = dict_sum[key] / 10
                    dict_rgb[key] = dict_rgb[key] / 10
                    dict_depth[key] = dict_depth[key] / 10
                    dict_of[key] = dict_of[key] / 10
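                    # Each video id apparently contributes 10 clips, so the accumulated
                    # per-clip logits are averaged by 10 (matching step_samples = len(...) / 10 above).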
                    pred_sum = np.argmax(dict_sum[key])
                    pred_r = np.argmax(dict_rgb[key])
                    pred_d = np.argmax(dict_depth[key])
                    pred_f = np.argmax(dict_of[key])
                    if 'ntu' in args.dset:
                        lab = int(key[-3:]) - 1
                    elif 'uwa3d' in args.dset:
                        lab = int(key[1:3]) - 1
                    elif 'nwucla' in args.dset:
                        lab = utils.get_nwucla_class(int(key[8:10])) - 1
                    if lab == pred_sum:
                        accum_correct_sum += 1
                        hist_acc_per_class[0, lab] += 1
                    if lab == pred_r:
                        hist_acc_per_class[2, lab] += 1
                        accum_correct_r += 1
                    if lab == pred_d:
                        hist_acc_per_class[3, lab] += 1
                        accum_correct_d += 1
                    if lab == pred_f:
                        hist_acc_per_class[4, lab] += 1
                        accum_correct_f += 1
                    if lab in [pred_f, pred_r, pred_d]:
                        accum_correct_oracle += 1
                        hist_acc_per_class[1, lab] += 1

                for i in range(n_classes):
                    value = n_samples_per_class[i]
                    hist_acc_per_class[:, i] = hist_acc_per_class[:, i] / value

                utils.double_log(
                    f_log,
                    'accuracy per class: 0sum 1oracle 2rgb 3depth 4flow \n')
                utils.double_log(
                    f_log, '\n' + np.array2string(hist_acc_per_class) + '\n')

                step_acc = accum_correct_r / step_samples
                utils.double_log(f_log, 'rgb Accuracy = %s \n' % str(step_acc))
                summary_acc = sess.run(step_summ_rgb,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)

                step_acc = accum_correct_d / step_samples
                utils.double_log(f_log,
                                 'depth Accuracy = %s \n' % str(step_acc))
                summary_acc = sess.run(step_summ_depth,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)

                step_acc = accum_correct_f / step_samples
                utils.double_log(f_log,
                                 'flow Accuracy = %s \n' % str(step_acc))
                summary_acc = sess.run(step_summ_flow,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)

                step_acc = accum_correct_sum / step_samples
                summary_acc = sess.run(step_summ_sum,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log, 'sum Accuracy = %s \n' % str(step_acc))

                step_acc = accum_correct_oracle / step_samples
                summary_acc = sess.run(step_summ_oracle,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log,
                                 'oracle Accuracy = %s \n' % str(step_acc))
            return step_acc
Example #6
def train(exp_id, train_fnames, val_fnames, test_fnames, n_classes):
    log_path, ckpt_path = utils.get_log_ckpt_path(args.dryrun)
    log_path = os.path.join(log_path, args.dset, exp_id)
    ckpt_path = os.path.join(ckpt_path, args.dset, exp_id)

    # dataset ######################################################
    with tf.device('/cpu:0'):
        dset_train = tf.data.TFRecordDataset(train_fnames,
                                             compression_type="GZIP")
        seed = tf.placeholder(tf.int64, shape=())  # =epoch
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.map(
            lambda x: parsers._parse_mult_frame_s2(x, rescale=False),
            num_parallel_calls=8)
        dset_train = dset_train.batch(args.batch_sz, drop_remainder=True)
        dset_train = dset_train.prefetch(buffer_size=10)

        dset_val = tf.data.TFRecordDataset(val_fnames, compression_type="GZIP")
        dset_val = dset_val.map(
            lambda x: parsers._parse_mult_frame_test_allmods(x, rescale=False),
            num_parallel_calls=8)
        dset_val = dset_val.batch(args.batch_sz)
        dset_val = dset_val.prefetch(buffer_size=10)

        dset_test = tf.data.TFRecordDataset(test_fnames,
                                            compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_mult_frame_test_allmods(x, rescale=False),
            num_parallel_calls=8)
        dset_test = dset_test.batch(args.batch_sz)
        dset_test = dset_test.prefetch(buffer_size=10)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_it = dset_train.make_initializable_iterator()
        val_it = dset_val.make_initializable_iterator()
        test_it = dset_test.make_initializable_iterator()
        next_element = iterator.get_next()
        batch_rgb = next_element[0]
        batch_depth = next_element[1]
        batch_oflow = next_element[2]
        batch_labels_raw = next_element[3]  # video labels
        batch_labels = tf.one_hot(batch_labels_raw, n_classes)
        batch_video_id = next_element[4]

    # tf Graph input ##############################################
    is_training = tf.placeholder(tf.bool, name="is_training")
    # with tf.device('/device:GPU:0'):
    net_oflow = resnet.Model(
        resnet_size=18,
        bottleneck=False,  # resnet original bottleneck, not ours
        num_classes=n_classes,
        num_filters=64,
        kernel_size=7,
        conv_stride=2,
        first_pool_size=3,
        first_pool_stride=2,
        block_sizes=[2, 2, 2, 2],
        block_strides=[1, 2, 2, 2],
        temporal_strides=[1, 2, 2, 2],
        resnet_version=1,
        data_format='channels_last',
        dtype=tf.float32,
        n_frames=args.n_frames,
        name='resnet_oflow')

    # with tf.device('/device:GPU:1'):
    net_rgb = resnet.Model(
        resnet_size=18,
        bottleneck=False,  # resnet original bottleneck, not ours
        num_classes=n_classes,
        num_filters=64,
        kernel_size=7,
        conv_stride=2,
        first_pool_size=3,
        first_pool_stride=2,
        block_sizes=[2, 2, 2, 2],
        block_strides=[1, 2, 2, 2],
        temporal_strides=[1, 2, 2, 2],
        resnet_version=1,
        data_format='channels_last',
        dtype=tf.float32,
        n_frames=args.n_frames,
        name='resnet_rgb')

    # with tf.device('/device:GPU:2'):
    net_depth = resnet.Model(
        resnet_size=18,
        bottleneck=False,  # resnet original bottleneck, not ours
        num_classes=n_classes,
        num_filters=64,
        kernel_size=7,
        conv_stride=2,
        first_pool_size=3,
        first_pool_stride=2,
        block_sizes=[2, 2, 2, 2],
        block_strides=[1, 2, 2, 2],
        temporal_strides=[1, 2, 2, 2],
        resnet_version=1,
        data_format='channels_last',
        dtype=tf.float32,
        n_frames=args.n_frames,
        name='resnet_depth')

    logits_oflow, reps_oflow = net_oflow(batch_oflow,
                                         training=is_training,
                                         output_rep=True)
    logits_depth, reps_depth = net_depth(batch_depth,
                                         training=is_training,
                                         output_rep=True)
    logits_rgb, reps_rgb = net_rgb(batch_rgb,
                                   training=is_training,
                                   output_rep=True)
    logits_fused = logits_rgb + logits_depth + logits_oflow

    ############################################################
    ############################################################
    # from https://github.com/chhwang/cmcl/blob/master/src/model.py
    logits_list = [logits_rgb, logits_depth, logits_oflow]
    closs_list = [
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=batch_labels_raw)
        for logits in logits_list
    ]
    # closs_list is a list with three elements, each of which is shape (batch_sz,)
    # min_index: after the transpose below it has shape (1, batch_sz) and holds,
    # per sample, the index of the stream with the minimum loss
    _, min_index = tf.nn.top_k(-tf.transpose(closs_list), 1)
    min_index = tf.transpose(min_index)
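    # Illustrative example (assumed values): with batch_sz=2 and per-stream losses
    # rgb=[0.9, 0.2], depth=[0.4, 0.7], oflow=[1.3, 0.5], the transposed rows are
    # [[0.9, 0.4, 1.3], [0.2, 0.7, 0.5]]; top_k on the negated rows picks the smallest
    # loss per sample, so min_index ends up as [[1, 0]]: depth wins sample 0, rgb wins sample 1.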

    ############################################################
    # SOFT
    # weight of loser distillation loss
    placeh_a_distill_l = tf.placeholder(tf.float32)
    # weight of loser ground truth loss
    placeh_a_loser_gt = tf.placeholder(tf.float32)
    placeh_temp = tf.placeholder(tf.float32)

    soft_softmax = [
        tf.nn.softmax(logits / placeh_temp) for logits in logits_list
    ]
    soft_winners = []
    for i in range(args.batch_sz):
        soft_winners.append(tf.gather_nd(soft_softmax, [min_index[0][i], i]))
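    # tf.gather_nd treats soft_softmax as a stacked (3, batch_sz, n_classes) tensor, so each
    # soft_winners[i] is the winning stream's temperature-softened distribution for sample i.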

    soft_loss_win = 0
    soft_loss_loser = 0
    for m in range(3):
        total_condition = tf.constant([False] * args.batch_sz, dtype=tf.bool)
        topk = 0
        # true if this modality m is the winner
        condition = tf.equal(min_index[topk], m)
        total_condition = tf.logical_or(
            total_condition, condition)  # used when topWinners > 0
        new_labels2 = tf.where(condition, batch_labels,
                               soft_winners)  # true, false
        new_logits = tf.where(condition, logits_list[m],
                              logits_list[m] / placeh_temp)

        loss_win = \
            tf.where(total_condition,
                     tf.stack([1.] * args.batch_sz),
                     tf.stack([placeh_a_loser_gt] * args.batch_sz)) * \
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.stop_gradient(batch_labels), logits=logits_list[m])
        soft_loss_win += tf.reduce_mean(loss_win)
        # placeh_a_loser_gt is the weight on the ground-truth loss for the losing streams.

        loss_losers = \
            tf.where(total_condition,
                     tf.stack([0.] * args.batch_sz),
                     tf.stack([placeh_a_distill_l] * args.batch_sz)) * \
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.stop_gradient(new_labels2), logits=new_logits)
        soft_loss_loser += tf.reduce_mean(loss_losers)
        # placeh_a_distill_l is the distillation weight to the losers

    # scaling loss due to distillation
    soft_loss_loser = soft_loss_loser * placeh_temp * placeh_temp
    loss = soft_loss_loser + soft_loss_win
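    # Restating the two terms above: for every modality m, the "win" term applies the
    # ground-truth cross-entropy with weight 1 where m is the per-sample winner and weight
    # placeh_a_loser_gt where it is not; the "loser" term adds, for non-winners only, a
    # distillation cross-entropy (weight placeh_a_distill_l, temperature placeh_temp) toward
    # the winner's softened softmax, rescaled by placeh_temp**2.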

    ###########################################################

    def exclude_batch_norm(name):
        # Batch-normalization and bias variables are excluded from the L2 weight decay.
        return 'batch_normalization' not in name and 'bias' not in name

    loss_filter_fn = exclude_batch_norm

    weight_decay = 1e-4
    l2_loss_of = weight_decay * tf.add_n([
        tf.nn.l2_loss(tf.cast(v, tf.float32))
        for v in tf.trainable_variables()
        if loss_filter_fn(v.name) and 'flow' in v.name
    ])
    l2_loss_depth = weight_decay * tf.add_n([
        tf.nn.l2_loss(tf.cast(v, tf.float32))
        for v in tf.trainable_variables()
        if loss_filter_fn(v.name) and 'depth' in v.name
    ])
    l2_loss_rgb = weight_decay * tf.add_n([
        tf.nn.l2_loss(tf.cast(v, tf.float32))
        for v in tf.trainable_variables()
        if loss_filter_fn(v.name) and 'rgb' in v.name
    ])

    loss += l2_loss_of + l2_loss_rgb + l2_loss_depth

    global_step = tf.train.get_or_create_global_step()
    if args.optimizer == 'Adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        learning_rate = args.learning_rate

    if args.optimizer == 'Momentum':
        # base_lr default is .0128
        lr_fn = utils.learning_rate_with_decay(
            batch_size=args.batch_sz,
            batch_denom=args.batch_sz,
            num_images=len(train_fnames),
            boundary_epochs=[100, 150, 180, 200],
            decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
            warmup=True,
            base_lr=.00128)
        learning_rate = lr_fn(global_step)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9)
    if args.optimizer == 'Momentum-finetune':
        # base_lr default is .0128
        lr_fn = utils.learning_rate_with_decay(
            batch_size=args.batch_sz,
            batch_denom=args.batch_sz,
            num_images=len(train_fnames),
            boundary_epochs=[100, 150, 180, 200],
            decay_rates=[1, 0.1, 0.01, 0.001, 1e-4],
            warmup=True,
            base_lr=.0000128)
        learning_rate = lr_fn(global_step)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9)

    grad_vars = optimizer.compute_gradients(loss)
    minimize_op = optimizer.apply_gradients(grad_vars, global_step)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = tf.group(minimize_op, update_ops)

    acc_train = utils.accuracy(logits_fused, batch_labels)
    acc_train_rgb = utils.accuracy(logits_rgb, batch_labels)
    acc_train_depth = utils.accuracy(logits_depth, batch_labels)
    acc_train_flow = utils.accuracy(logits_oflow, batch_labels)

    ########### SUMMARIES ######################################
    for gradient, variable in grad_vars:
        if '/dense/' in variable.name or 'conv3d_1/' in variable.name:
            tf.summary.scalar("norm/gradients/" + variable.name,
                              tf.norm(gradient))
            tf.summary.scalar("norm/variables/" + variable.name,
                              tf.norm(variable))
            tf.summary.histogram("values/variables/" + variable.name,
                                 tf.reshape(variable, [-1]))
            tf.summary.histogram("values/gradients/" + variable.name,
                                 tf.reshape(gradient, [-1]))

    # reduce_mean to get the mean for the batch
    tf.summary.scalar('norm/logits_rgb',
                      tf.reduce_mean(tf.map_fn(tf.norm, logits_rgb)))
    tf.summary.scalar('norm/logits_depth',
                      tf.reduce_mean(tf.map_fn(tf.norm, logits_depth)))
    tf.summary.scalar('norm/logits_of',
                      tf.reduce_mean(tf.map_fn(tf.norm, logits_oflow)))
    summaries_norms_logits = tf.summary.merge_all(scope='norm/')
    tf.summary.histogram('values/logits_rgb', tf.reshape(logits_rgb, [-1]))
    tf.summary.histogram('values/logits_depth', tf.reshape(logits_depth, [-1]))
    tf.summary.histogram('values/logits_of', tf.reshape(logits_oflow, [-1]))
    summaries_vars = tf.summary.merge_all(scope='values/')
    # tf.summary.scalar('loss/xentropy', loss_total)
    # tf.summary.scalar('loss/xentropy_rgb', cross_entropy_rgb)
    # tf.summary.scalar('loss/xentropy_depth', cross_entropy_depth)
    # tf.summary.scalar('loss/xentropy_of', cross_entropy_of)
    # tf.summary.scalar('loss/l2', l2_loss)
    tf.summary.scalar('loss/total', loss)
    summaries_losses = tf.summary.merge_all(scope='loss/')
    tf.summary.scalar('acc/train', acc_train)
    tf.summary.scalar('acc/train_rgb', acc_train_rgb)
    tf.summary.scalar('acc/train_depth', acc_train_depth)
    tf.summary.scalar('acc/train_flow', acc_train_flow)
    summaries_acc_train = tf.summary.merge_all(scope='acc/')
    summ_lr = tf.summary.scalar('learning_rate', learning_rate)
    summ_train = tf.summary.merge([
        summaries_norms_logits, summaries_vars, summ_lr, summaries_acc_train,
        summaries_losses
    ])
    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summ_acc_test_oracle = tf.summary.scalar('acc_test_oracle',
                                             accuracy_value_)
    summ_acc_test_sum = tf.summary.scalar('acc_test_sum', accuracy_value_)
    summ_acc_test_rgb = tf.summary.scalar('acc_test_rgb', accuracy_value_)
    summ_acc_test_flow = tf.summary.scalar('acc_test_flow', accuracy_value_)
    summ_acc_test_depth = tf.summary.scalar('acc_test_depth', accuracy_value_)
    summ_acc_val_oracle = tf.summary.scalar('acc_val_oracle', accuracy_value_)
    summ_acc_val_sum = tf.summary.scalar('acc_val_sum', accuracy_value_)
    summ_acc_val_rgb = tf.summary.scalar('acc_val_rgb', accuracy_value_)
    summ_acc_val_flow = tf.summary.scalar('acc_val_flow', accuracy_value_)
    summ_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    #################################################

    test_saver = tf.train.Saver(max_to_keep=3)
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True
    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_it.string_handle())
        val_handle = sess.run(val_it.string_handle())
        test_handle = sess.run(test_it.string_handle())
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(log_path, sess.graph)
        f_log = utils.init_f_log(log_path, exp_id, sys.argv[:])

        def val_test(value_step, mode='val'):
            if mode == 'val' and val_fnames:
                utils.double_log(f_log, "eval val set \n")
                sess.run(val_it.initializer)
                step_handle = val_handle
                step_samples = len(val_fnames) / 10
                step_summ_rgb = summ_acc_val_rgb
                step_summ_depth = summ_acc_val_depth
                step_summ_flow = summ_acc_val_flow
                step_summ_sum = summ_acc_val_sum
                step_summ_oracle = summ_acc_val_oracle
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_it.initializer)
                step_handle = test_handle
                step_samples = len(test_fnames) / 10
                step_summ_rgb = summ_acc_test_rgb
                step_summ_depth = summ_acc_test_depth
                step_summ_flow = summ_acc_test_flow
                step_summ_sum = summ_acc_test_sum
                step_summ_oracle = summ_acc_test_oracle
            else:
                return -1
            try:
                dict_sum = {}
                dict_rgb = {}
                dict_of = {}
                dict_depth = {}
                hist_acc_per_class = np.zeros([5, n_classes])
                n_samples_per_class = np.zeros(n_classes)
                while True:
                    logits_sum, logits_r, logits_d, logits_f, video_id_val = sess.run(
                        [
                            logits_fused, logits_rgb, logits_depth,
                            logits_oflow, batch_video_id
                        ],
                        feed_dict={
                            handle: step_handle,
                            is_training: False
                        })
                    for i in range(len(video_id_val)):
                        v = video_id_val[i]
                        if v in dict_sum:
                            dict_sum[v] = dict_sum[v] + logits_sum[i]
                        else:
                            dict_sum[v] = logits_sum[i]
                            if 'ntu' in args.dset:
                                this_class = int(v[-3:]) - 1
                            elif 'nwucla' in args.dset:
                                this_class = utils.get_nwucla_class(
                                    int(v[8:10])) - 1
                            elif 'uwa3dii' in args.dset:
                                this_class = int(v[1:3]) - 1
                            n_samples_per_class[this_class] += 1
                        if v in dict_rgb:
                            dict_rgb[v] = dict_rgb[v] + logits_r[i]
                        else:
                            dict_rgb[v] = logits_r[i]
                        if v in dict_of:
                            dict_of[v] = dict_of[v] + logits_f[i]
                        else:
                            dict_of[v] = logits_f[i]
                        if v in dict_depth:
                            dict_depth[v] = dict_depth[v] + logits_d[i]
                        else:
                            dict_depth[v] = logits_d[i]
            except tf.errors.OutOfRangeError:
                accum_correct_sum = 0
                accum_correct_r = 0
                accum_correct_d = 0
                accum_correct_f = 0
                accum_correct_oracle = 0
                for key in dict_sum:
                    dict_sum[key] = dict_sum[key] / 10
                    dict_rgb[key] = dict_rgb[key] / 10
                    dict_depth[key] = dict_depth[key] / 10
                    dict_of[key] = dict_of[key] / 10
                    pred_sum = np.argmax(dict_sum[key])
                    pred_r = np.argmax(dict_rgb[key])
                    pred_d = np.argmax(dict_depth[key])
                    pred_f = np.argmax(dict_of[key])
                    if 'ntu' in args.dset:
                        lab = int(key[-3:]) - 1
                    elif 'uwa3d' in args.dset:
                        lab = int(key[1:3]) - 1
                    elif 'nwucla' in args.dset:
                        lab = utils.get_nwucla_class(int(key[8:10])) - 1
                    if lab == pred_sum:
                        accum_correct_sum += 1
                        hist_acc_per_class[0, lab] += 1
                    if lab == pred_r:
                        hist_acc_per_class[2, lab] += 1
                        accum_correct_r += 1
                    if lab == pred_d:
                        hist_acc_per_class[3, lab] += 1
                        accum_correct_d += 1
                    if lab == pred_f:
                        hist_acc_per_class[4, lab] += 1
                        accum_correct_f += 1
                    if lab in [pred_f, pred_r, pred_d]:
                        accum_correct_oracle += 1
                        hist_acc_per_class[1, lab] += 1

                for i in range(n_classes):
                    value = n_samples_per_class[i]
                    hist_acc_per_class[:, i] = hist_acc_per_class[:, i] / value

                utils.double_log(
                    f_log,
                    'accuracy per class: 0sum 1oracle 2rgb 3depth 4flow \n')
                utils.double_log(
                    f_log, '\n' + np.array2string(hist_acc_per_class) + '\n')

                step_acc = accum_correct_r / step_samples
                utils.double_log(f_log, 'rgb Accuracy = %s \n' % str(step_acc))
                summary_acc = sess.run(step_summ_rgb,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)

                step_acc = accum_correct_d / step_samples
                utils.double_log(f_log,
                                 'depth Accuracy = %s \n' % str(step_acc))
                summary_acc = sess.run(step_summ_depth,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)

                step_acc = accum_correct_f / step_samples
                utils.double_log(f_log,
                                 'flow Accuracy = %s \n' % str(step_acc))
                summary_acc = sess.run(step_summ_flow,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)

                step_acc = accum_correct_sum / step_samples
                summary_acc = sess.run(step_summ_sum,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log, 'sum Accuracy = %s \n' % str(step_acc))

                step_acc = accum_correct_oracle / step_samples
                summary_acc = sess.run(step_summ_oracle,
                                       feed_dict={accuracy_value_: step_acc})
                summary_writer.add_summary(summary_acc, value_step)
                utils.double_log(f_log,
                                 'oracle Accuracy = %s \n' % str(step_acc))
            return step_acc

        period_val_acc = 5  # run validation every period_val_acc epochs
        patience = 5  # give up after patience checks (period_val_acc * patience epochs) with no improvement
        validation_accuracies = [0] * patience
        n_step = 0
        best_step = 0

        for epoch in range(args.n_epochs):
            utils.double_log(f_log, 'epoch %s \n' % str(epoch))
            sess.run(train_it.initializer, feed_dict={seed: epoch})
            accum_min_index = np.zeros(3)
            hist_classes = np.zeros([3, n_classes])
            a_distill_l = args.a_distill_l
            a_loser_gt = args.a_loser_gt
            temp = args.temp
            try:
                while True:
                    if n_step % args.step_summ == 0:  # get summaries
                        batch_labels_raw_val, min_index_val, _, summary = sess.run(
                            [
                                batch_labels_raw, min_index, train_op,
                                summ_train
                            ],
                            feed_dict={
                                handle: train_handle,
                                is_training: True,
                                placeh_a_distill_l: a_distill_l,
                                placeh_a_loser_gt: a_loser_gt,
                                placeh_temp: temp
                            })
                        min_index_val = np.squeeze(min_index_val)
                        min_index_val = [int(x) for x in min_index_val]
                        for i, j in enumerate(min_index_val):
                            accum_min_index[j] += 1
                            hist_classes[j][batch_labels_raw_val[i]] += 1
                        summary_writer.add_summary(summary, n_step)
                    else:
                        batch_labels_raw_val, min_index_val, _ = sess.run(
                            [batch_labels_raw, min_index, train_op],
                            feed_dict={
                                handle: train_handle,
                                is_training: True,
                                placeh_a_distill_l: a_distill_l,
                                placeh_a_loser_gt: a_loser_gt,
                                placeh_temp: temp
                            })
                        min_index_val = np.squeeze(min_index_val)
                        min_index_val = [int(x) for x in min_index_val]
                        for i, j in enumerate(min_index_val):
                            accum_min_index[j] += 1
                            hist_classes[j][batch_labels_raw_val[i]] += 1
                    n_step = n_step + 1
            except tf.errors.OutOfRangeError:
                utils.double_log(f_log, 'nets: rgb depth oflow \n')
                utils.double_log(
                    f_log,
                    'total number of examples that each net saw this epoch \n')
                utils.double_log(f_log,
                                 np.array2string(accum_min_index) + '\n')
                utils.double_log(
                    f_log,
                    'total number of examples that each net saw this epoch, per class\n'
                )
                utils.double_log(f_log,
                                 '\n' + np.array2string(hist_classes) + '\n')
                if epoch % period_val_acc == 0:
                    validation_accuracy = val_test(n_step, mode='val')
                    validation_accuracies.append(validation_accuracy)
                    if validation_accuracy >= np.max(validation_accuracies):
                        best_step = n_step
                        test_saver.save(sess,
                                        os.path.join(ckpt_path,
                                                     'test/model.ckpt'),
                                        global_step=n_step)
                        continue
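                    # Early stopping: if the current validation accuracy does not improve on
                    # any check older than the last `patience` ones (and we are past epoch 10),
                    # training stops; otherwise keep going.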
                    elif not any(x < validation_accuracy
                                 for x in validation_accuracies[:-patience]
                                 ) and epoch > 10:
                        break
                    else:
                        continue
                else:
                    continue

        utils.double_log(f_log, "Optimization Finished!\n")
        test_saver.save(sess,
                        os.path.join(ckpt_path, 'test/model.ckpt'),
                        global_step=n_step)
        val_test(n_step + 1, mode='test')
        restorers.restore_all_weights(
            sess, os.path.join(ckpt_path, 'test/model.ckpt-' + str(best_step)))
        val_test(n_step + 2, mode='test')
        f_log.close()
        summary_writer.close()
Example #7
def train(exp_id, files, args):
    hallucination_layer_true = 'resnet_v1_50_of/block4'
    hallucination_layer_hall = 'resnet_v1_50_hall/block4'

    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=True))
        seed = tf.placeholder(tf.int64, shape=())
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(
                lambda x: parsers._parse_fun_2stream(x, is_training=False))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=False))
        dset_test = dset_test.batch(args.batch_sz)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_depth_stacked = next_element[0]  # [batch, pooled_frames, h,w,c]
        images_rgb_stacked = next_element[1]
        if args.dset == 'uwa3dii':  # because tfrecords labels are [1,30]
            labels = next_element[2] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[2]
            labels_per_frame = next_element[3]
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)
        labels_per_frame = tf.reshape(labels_per_frame, [-1])
        labels_per_frame = tf.one_hot(labels_per_frame, n_classes)

        rgb_stack_shape = tf.shape(images_rgb_stacked)
        depth_stack_shape = tf.shape(images_depth_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        images_rgb = tf.reshape(
            images_rgb_stacked,
            [rgb_stack_shape[0] * rgb_stack_shape[1], 224, 224, 3])
        images_depth = tf.reshape(
            images_depth_stacked,
            [depth_stack_shape[0] * depth_stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            is_training = tf.placeholder(tf.bool, [])
            nr_frames = parsers.time_bottleneck

            net_depth_out, endpoints_depth = resnet_v1.resnet_one_stream_main(
                images_depth,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_depth',
                gpu_id='/gpu:0',
                is_training=False,
                bottleneck=True)

            net_hall_out, net_hall_endpoints = resnet_v1.resnet_one_stream_main(
                images_rgb,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_hall',
                gpu_id='/gpu:1',
                is_training=is_training,
                bottleneck=True)

            nr_batch_frames = tf.shape(labels_per_frame)
            nr_batch_vid = tf.shape(labels)

            temporal_order = utils.get_temporal_order_onehot(
                nr_batch_vid[0], parsers.time_bottleneck)
            temporal_order = tf.expand_dims(temporal_order, axis=1)
            temporal_order = tf.expand_dims(temporal_order, axis=1)

            feat_depth = endpoints_depth['last_pool']
            logits_real = resnet_v1.feature_discriminator(feat_depth,
                                                          temporal_order,
                                                          n_classes=n_classes)

            feat_hall = net_hall_endpoints['last_pool']
            logits_fake = resnet_v1.feature_discriminator(feat_hall,
                                                          temporal_order,
                                                          n_classes=n_classes,
                                                          reuse=True)

            logits_real = tf.squeeze(logits_real, [1, 2])
            logits_fake = tf.squeeze(logits_fake, [1, 2])

            # TRAIN ###############################
            net_depth_train = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_train = tf.reduce_mean(net_depth_train, axis=1)
            net_hall_train = tf.reshape(
                net_hall_out, [-1, parsers.time_bottleneck, n_classes])
            net_hall_train = tf.reduce_mean(net_hall_train, axis=1)

            # TEST ###############################
            net_hall_test = tf.reshape(net_hall_out,
                                       [-1, parsers.time_bottleneck, n_classes])
            net_hall_test = tf.reduce_mean(net_hall_test, axis=1)
            net_depth_test = tf.reshape(net_depth_out,
                                        [-1, parsers.time_bottleneck, n_classes])
            net_depth_test = tf.reduce_mean(net_depth_test, axis=1)

            # losses ##########################################################
            d_target_dist_real = tf.concat(
                axis=-1,
                values=[
                    tf.zeros([nr_batch_frames[0], 1], tf.float32),
                    tf.cast(labels_per_frame, tf.float32)
                ])
            d_loss_real = slim.losses.softmax_cross_entropy(
                logits_real, d_target_dist_real)
            d_target_dist_fake = tf.concat(
                axis=-1,
                values=[
                    tf.ones([nr_batch_frames[0], 1], tf.float32),
                    tf.zeros([nr_batch_frames[0], n_classes], tf.float32)
                ])
            d_loss_fake = slim.losses.softmax_cross_entropy(
                logits_fake, d_target_dist_fake)
            d_loss = .5 * (d_loss_real + d_loss_fake)

            g_target_dist_fake = tf.concat(
                axis=-1,
                values=[
                    tf.zeros([nr_batch_frames[0], 1], tf.float32),
                    tf.cast(labels_per_frame, tf.float32)
                ])
            g_loss = slim.losses.softmax_cross_entropy(logits_fake,
                                                       g_target_dist_fake)
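            # Discriminator targets: real (depth) features should be classified into their
            # action class ([0, one-hot label]); hallucinated features into the extra "fake"
            # slot ([1, zeros]). The generator (hallucination stream) is trained so that its
            # features are classified as the true action class instead ([0, one-hot label]).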

            loss_hall_rect_static = utils.loss_hall_rect(
                endpoints_depth[hallucination_layer_true],
                net_hall_endpoints[hallucination_layer_hall])
            loss_hall_rect_static2 = utils.loss_hall_rect(
                endpoints_depth['last_pool'], net_hall_endpoints['last_pool'])

            d_optimizer = tf.train.AdamOptimizer(args.learning_rate)
            g_optimizer = tf.train.AdamOptimizer(args.learning_rate)

            t_vars = tf.trainable_variables()
            # freezing depth
            depth_vars = [x for x in t_vars if 'resnet_v1_50_of' in x.name]
            to_remove = depth_vars
            train_vars = [x for x in t_vars if x not in to_remove]
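            # D and G receive disjoint variable sets: the discriminator step only updates
            # 'disc_e' variables and the generator step only the hallucination stream,
            # while the (frozen) depth-stream variables are excluded from both.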

            train_vars_d = [x for x in train_vars if 'disc_e' in x.name]
            minimizing_d = slim.learning.create_train_op(
                d_loss, d_optimizer, variables_to_train=train_vars_d)

            train_vars_g = [
                x for x in train_vars if 'resnet_v1_50_hall' in x.name
            ]
            minimizing_g = slim.learning.create_train_op(
                g_loss, g_optimizer, variables_to_train=train_vars_g)

            ###################################################################
            acc_depth_train = utils.accuracy(net_depth_train, labels)
            acc_hall_train = utils.accuracy(net_hall_train, labels)

            n_correct_depth = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_depth_test, labels),
                        tf.float32))
            n_correct_hall = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_hall_test, labels), tf.float32))
            ###################################################################

    summ_d_loss = tf.summary.scalar('d_loss', d_loss)
    summ_d_loss_real = tf.summary.scalar('d_loss_real', d_loss_real)
    summ_d_loss_fake = tf.summary.scalar('d_loss_fake', d_loss_fake)
    summ_g_loss = tf.summary.scalar('g_loss', g_loss)
    summ_loss_hall_rect_static = tf.summary.scalar('loss_euclid_hall_layer',
                                                   loss_hall_rect_static)
    summ_loss_hall_rect_static2 = tf.summary.scalar('loss_euclid_hall_2',
                                                    loss_hall_rect_static2)
    summ_acc_train_hall = tf.summary.scalar('acc_train_hall', acc_hall_train)
    summ_acc_train_depth = tf.summary.scalar('acc_train_depth',
                                             acc_depth_train)
    summary_train = tf.summary.merge([
        summ_d_loss, summ_d_loss_real, summ_d_loss_fake, summ_g_loss,
        summ_loss_hall_rect_static, summ_loss_hall_rect_static2,
        summ_acc_train_hall, summ_acc_train_depth
    ])

    accuracy_value_ = tf.placeholder(tf.float32, shape=())
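    # Val/test accuracies are computed in Python (correct counts / #samples) and fed
    # back through this placeholder solely to emit TensorBoard scalar summaries.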
    summ_acc_val_hall = tf.summary.scalar('acc_val_hall', accuracy_value_)
    summ_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    summ_acc_test_hall = tf.summary.scalar('acc_test_hall', accuracy_value_)
    summ_acc_test_depth = tf.summary.scalar('acc_test_depth', accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, args.dset, exp_id), sess.graph)

        f_log = open(os.path.join(log_path, args.dset, exp_id, 'log.txt'), 'a')
        utils.double_log(
            f_log, '\n###############################################\n' +
            exp_id + '\n#####################################\n')
        f_log.write(' '.join(sys.argv[:]) + '\n')
        f_log.flush()

        sess.run(tf.global_variables_initializer())
        if args.ckpt == '':
            sys.exit('Please specify the depth checkpoint')
        restorers.restore_weights_s2_5_gan_depth(sess, args.ckpt)

        def val_test(value_step, mode='val'):
            if mode == 'val':
                if not val_filenames:
                    return -1
                utils.double_log(f_log, "eval validation set \n")
                sess.run(val_iterator.initializer)
                step_handle = val_handle
                step_samples = len(val_filenames)
                step_summ_hall = summ_acc_val_hall
                step_summ_depth = summ_acc_val_depth
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_iterator.initializer)
                step_handle = test_handle
                step_samples = len(test_filenames)
                step_summ_hall = summ_acc_test_hall
                step_summ_depth = summ_acc_test_depth
            try:
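                # Iterate over the whole split; the dataset iterator raises
                # OutOfRangeError when exhausted, and the accumulated correct
                # counts are then normalized into per-split accuracies.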
                accum_correct_depth = accum_correct_hall = 0
                while True:
                    n_correct_hall1, n_correct_depth1 = sess.run(
                        [n_correct_hall, n_correct_depth],
                        feed_dict={
                            handle: step_handle,
                            is_training: False
                        })
                    accum_correct_depth += n_correct_depth1
                    accum_correct_hall += n_correct_hall1
            except tf.errors.OutOfRangeError:
                acc_hall = accum_correct_hall / step_samples
                acc_depth = accum_correct_depth / step_samples
                summ_hall_acc = sess.run(step_summ_hall,
                                         feed_dict={accuracy_value_: acc_hall})
                summary_writer.add_summary(summ_hall_acc, value_step)
                summ_depth_acc = sess.run(
                    step_summ_depth, feed_dict={accuracy_value_: acc_depth})
                summary_writer.add_summary(summ_depth_acc, value_step)
                utils.double_log(f_log, 'Hall acc = %s \n' % str(acc_hall))
                utils.double_log(f_log, 'Depth acc = %s \n' % str(acc_depth))
                return acc_hall

        if args.just_eval:
            val_test(-1, mode='test')
            f_log.close()
            summary_writer.close()
            return

        val_test(-1, mode='val')
        val_test(-1, mode='test')
        n_step = 0
        best_acc = best_epoch = best_step = -1
        for epoch in range(args.n_epochs):
            utils.double_log(f_log, 'epoch %s \n' % str(epoch))
            sess.run(train_iterator.initializer, feed_dict={seed: epoch})
            try:
                while True:
                    print(n_step)
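                    # Discriminator and generator are updated on the same batch in a
                    # single sess.run (simultaneous rather than alternating updates);
                    # training summaries are written every 100 steps.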
                    if n_step % 100 == 0:
                        _, _, summ_train = sess.run(
                            [minimizing_d, minimizing_g, summary_train],
                            feed_dict={
                                handle: train_handle,
                                nr_frames: parsers.time_bottleneck,
                                is_training: True
                            })
                        summary_writer.add_summary(summ_train, n_step)
                    else:
                        sess.run(
                            [minimizing_d, minimizing_g],
                            feed_dict={
                                handle: train_handle,
                                nr_frames: parsers.time_bottleneck,
                                is_training: True
                            })
                    n_step += 1
            except tf.errors.OutOfRangeError:
                acc_validation = val_test(n_step, mode='val')

            if val_filenames:
                acc_epoch = acc_validation
            else:
                continue
            if acc_epoch >= best_acc:
                best_acc = acc_epoch
                best_epoch = epoch
                best_step = n_step
                test_saver.save(sess,
                                os.path.join(ckpt_path, args.dset, exp_id,
                                             'test/model.ckpt'),
                                global_step=n_step)

        utils.double_log(f_log, "Optimization Finished!\n")
        if val_filenames:  # restore best validation model
            utils.double_log(
                f_log,
                str("Best Validation Accuracy: %f at epoch %d %d\n" %
                    (best_acc, best_epoch, best_step)))
            variables_to_restore = slim.get_variables_to_restore()
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(
                sess,
                os.path.join(ckpt_path, args.dset, exp_id,
                             'test/model.ckpt-' + str(best_step)))
        else:
            test_saver.save(sess,
                            os.path.join(ckpt_path, args.dset, exp_id,
                                         'test/model.ckpt'),
                            global_step=n_step)

        val_test(n_step + 1, mode='test')
        f_log.close()
        summary_writer.close()
Exemplo n.º 8
0
def train(exp_id, files, args):
    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=True))
        seed = tf.placeholder(tf.int64, shape=())
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(
                lambda x: parsers._parse_fun_2stream(x, is_training=False))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=False))
        dset_test = dset_test.batch(args.batch_sz)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)
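        # One feedable iterator drives the graph: the string handle fed at run time
        # selects whether the next batch comes from the train, val, or test iterator.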

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_depth_stacked = next_element[0]  # [batch, pooled_frames, h,w,c]
        images_rgb_stacked = next_element[1]
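        # Depth-corruption switches: a positive args.noise multiplies the depth frames
        # by multiplicative Gaussian noise (then re-quantizes via saturate_cast), while
        # a negative value zeroes the depth input so only RGB carries signal.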

        if args.noise > 0:
            images_depth_stacked = images_depth_stacked * tf.random_normal(
                shape=tf.shape(images_depth_stacked),
                mean=1.0,
                stddev=args.noise,
                dtype=tf.float32)
            images_depth_stacked = tf.saturate_cast(images_depth_stacked,
                                                    dtype=tf.uint8)
            images_depth_stacked = tf.to_float(images_depth_stacked)
        elif args.noise < 0:
            images_depth_stacked = tf.zeros(tf.shape(images_depth_stacked),
                                            tf.float32)

        if args.dset == 'uwa3dii':  # uwa3dii tfrecord labels run 1..30; shift to zero-based
            labels = next_element[2] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[2]
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)

        rgb_stack_shape = tf.shape(images_rgb_stacked)
        depth_stack_shape = tf.shape(images_depth_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        images_rgb = tf.reshape(
            images_rgb_stacked,
            [rgb_stack_shape[0] * rgb_stack_shape[1], 224, 224, 3])
        images_depth = tf.reshape(
            images_depth_stacked,
            [depth_stack_shape[0] * depth_stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            is_training = tf.placeholder(tf.bool, [])
            nr_frames = parsers.time_bottleneck

            net_depth_out, endpoints_depth = resnet_v1.resnet_one_stream_main(
                images_depth,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_depth',
                gpu_id=args.gpu0,
                is_training=is_training,
                bottleneck=True)
            net_rgb_out, endpoints_rgb = resnet_v1.resnet_one_stream_main(
                images_rgb,
                nr_frames,
                num_classes=n_classes,
                scope='resnet_v1_50_rgb',
                gpu_id='/gpu:1',
                is_training=is_training,
                bottleneck=False)
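            # Two independent ResNet-50 streams, one per modality, pinned to different
            # devices (args.gpu0 for depth, /gpu:1 for RGB); only the depth stream is
            # built with the bottleneck option.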

            # predictions for each video are the avg of frames' predictions
            # TRAIN ###############################
            net_depth_train = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_train = tf.reduce_mean(net_depth_train, axis=1)
            net_rgb_train = tf.reshape(
                net_rgb_out, [-1, parsers.time_bottleneck, n_classes])
            net_rgb_train = tf.reduce_mean(net_rgb_train, axis=1)
            net_combined_train = tf.add(net_depth_train, net_rgb_train) / 2.0

            # TEST ###############################
            net_rgb_test = tf.reshape(net_rgb_out,
                                      [-1, parsers.time_bottleneck, n_classes])
            net_rgb_test = tf.reduce_mean(net_rgb_test, axis=1)
            net_depth_test = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_test = tf.reduce_mean(net_depth_test, axis=1)
            net_combined_test = tf.add(net_rgb_test, net_depth_test) / 2.0

            # losses ##########################################################
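            # Late fusion: the per-stream class scores are averaged and only the combined
            # cross-entropy is minimized below; the per-stream losses are computed so they
            # can be tracked in the training summaries.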
            loss_combined = slim.losses.softmax_cross_entropy(
                net_combined_train, labels)
            loss_depth = slim.losses.softmax_cross_entropy(
                net_depth_train, labels)
            loss_rgb = slim.losses.softmax_cross_entropy(net_rgb_train, labels)

            optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            minimizing = slim.learning.create_train_op(loss_combined,
                                                       optimizer)

            acc_depth_train = utils.accuracy(net_depth_train, labels)
            acc_rgb_train = utils.accuracy(net_rgb_train, labels)
            acc_combined_train = utils.accuracy(net_combined_train, labels)

            n_correct_depth = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_depth_test, labels),
                        tf.float32))
            n_correct_rgb = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_rgb_test, labels), tf.float32))
            n_correct_combined = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_combined_test, labels),
                        tf.float32))

    summ_loss_combined = tf.summary.scalar('loss_combined', loss_combined)
    summ_loss_depth = tf.summary.scalar('loss_depth', loss_depth)
    summ_loss_rgb = tf.summary.scalar('loss_rgb', loss_rgb)
    summ_acc_train_rgb = tf.summary.scalar('acc_train_rgb', acc_rgb_train)
    summ_acc_train_depth = tf.summary.scalar('acc_train_depth',
                                             acc_depth_train)
    summ_acc_train_combined = tf.summary.scalar('acc_train_combined',
                                                acc_combined_train)
    summary_train = tf.summary.merge([
        summ_acc_train_rgb, summ_acc_train_depth, summ_acc_train_combined,
        summ_loss_depth, summ_loss_rgb, summ_loss_combined
    ])

    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summary_acc_val_rgb = tf.summary.scalar('acc_val_rgb', accuracy_value_)
    summary_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    summary_acc_val_combined = tf.summary.scalar('acc_val_combined',
                                                 accuracy_value_)
    summary_acc_test_rgb = tf.summary.scalar('acc_test_rgb', accuracy_value_)
    summary_acc_test_depth = tf.summary.scalar('acc_test_depth',
                                               accuracy_value_)
    summary_acc_test_combined = tf.summary.scalar('acc_test_combined',
                                                  accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, args.dset, exp_id), sess.graph)

        f_log = open(os.path.join(log_path, args.dset, exp_id, 'log.txt'), 'a')
        utils.double_log(
            f_log, '\n###############################################\n' +
            exp_id + '\n#####################################\n')
        f_log.write(' '.join(sys.argv[:]) + '\n')
        f_log.flush()

        sess.run(tf.global_variables_initializer())
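        # Note: s1_rgb_ckpt and s1_depth_ckpt are assumed to be stage-1 (single-stream)
        # checkpoint paths defined at module level; they are not set inside this function.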
        if args.ckpt == '':
            restorers.restore_weights_s2(sess, s1_rgb_ckpt, s1_depth_ckpt)
        else:
            restorers.restore_weights_s2_continue(sess, args.ckpt)

        def val_test(value_step, mode='val'):
            if mode == 'val':
                if not val_filenames:
                    return -1
                utils.double_log(f_log, "eval validation set \n")
                sess.run(val_iterator.initializer)
                step_handle = val_handle
                step_samples = len(val_filenames)
                step_summ_rgb = summary_acc_val_rgb
                step_summ_depth = summary_acc_val_depth
                step_summ_combined = summary_acc_val_combined
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_iterator.initializer)
                step_handle = test_handle
                step_samples = len(test_filenames)
                step_summ_rgb = summary_acc_test_rgb
                step_summ_depth = summary_acc_test_depth
                step_summ_combined = summary_acc_test_combined
            try:
                accum_correct_rgb = accum_correct_depth = accum_correct_combined_val = 0
                while True:
                    n_correct_rgb1, n_correct_depth1, n_correct_combined1 = sess.run(
                        [n_correct_rgb, n_correct_depth, n_correct_combined],
                        feed_dict={
                            handle: step_handle,
                            is_training: False
                        })
                    accum_correct_rgb += n_correct_rgb1
                    accum_correct_depth += n_correct_depth1
                    accum_correct_combined_val += n_correct_combined1
            except tf.errors.OutOfRangeError:
                acc_rgb = accum_correct_rgb / step_samples
                acc_depth = accum_correct_depth / step_samples
                acc_combined = accum_correct_combined_val / step_samples
                sum_rgb_acc = sess.run(step_summ_rgb,
                                       feed_dict={accuracy_value_: acc_rgb})
                summary_writer.add_summary(sum_rgb_acc, value_step)
                sum_depth_acc = sess.run(
                    step_summ_depth, feed_dict={accuracy_value_: acc_depth})
                summary_writer.add_summary(sum_depth_acc, value_step)
                sum_combined_acc = sess.run(
                    step_summ_combined,
                    feed_dict={accuracy_value_: acc_combined})
                summary_writer.add_summary(sum_combined_acc, value_step)
                utils.double_log(f_log,
                                 'Depth accuracy = %s \n' % str(acc_depth))
                utils.double_log(f_log, 'RGB accuracy = %s \n' % str(acc_rgb))
                utils.double_log(
                    f_log, 'combined accuracy = %s \n' % str(acc_combined))
                return acc_combined

        if args.just_eval:
            val_test(-1, mode='test')
            f_log.close()
            summary_writer.close()
            return

        val_test(-1, mode='val')
        val_test(-1, mode='test')
        n_step = 0
        best_acc = best_epoch = best_step = -1
        for epoch in range(args.n_epochs):
            utils.double_log(f_log, 'epoch %s \n' % str(epoch))
            sess.run(train_iterator.initializer, feed_dict={seed: epoch})
            try:
                while True:
                    print(n_step)
                    if n_step % 100 == 0:  # get summaries
                        _, summary = sess.run([minimizing, summary_train],
                                              feed_dict={
                                                  handle: train_handle,
                                                  is_training: True
                                              })
                        summary_writer.add_summary(summary, n_step)
                    else:
                        sess.run([minimizing],
                                 feed_dict={
                                     handle: train_handle,
                                     is_training: True
                                 })
                    n_step += 1
            except tf.errors.OutOfRangeError:
                acc_validation = val_test(n_step, mode='val')

            if val_filenames:
                acc_epoch = acc_validation
            else:
                continue
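            # Checkpoint whenever validation accuracy ties or improves on the best so
            # far; the best checkpoint is restored below for the final test evaluation.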
            if acc_epoch >= best_acc:
                best_acc = acc_epoch
                best_epoch = epoch
                best_step = n_step
                test_saver.save(sess,
                                os.path.join(ckpt_path, args.dset, exp_id,
                                             'test/model.ckpt'),
                                global_step=n_step)

        utils.double_log(f_log, "Optimization Finished!\n")
        if val_filenames:  # restore best validation model
            utils.double_log(
                f_log,
                str("Best Validation Accuracy: %f at epoch %d \n" %
                    (best_acc, best_epoch)))
            variables_to_restore = slim.get_variables_to_restore()
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(
                sess,
                os.path.join(ckpt_path, args.dset, exp_id,
                             'test/model.ckpt-' + str(best_step)))
        else:
            test_saver.save(sess,
                            os.path.join(ckpt_path, args.dset, exp_id,
                                         'test/model.ckpt'),
                            global_step=n_step)

        val_test(n_step + 1, mode='test')
        f_log.close()
        summary_writer.close()
Exemplo n.º 9
0
def train(exp_id, files, args):
    hallucination_layer_true = 'resnet_v1_50_depth/block4'
    hallucination_layer_hall = 'resnet_v1_50_hall/block4'
    scope_rgb = 'hall'

    log_path = './log'
    ckpt_path = './checkpoint'

    # dataset ######################################################
    train_filenames, val_filenames, test_filenames = utils.get_tfrecords(
        args.eval_mode, files['data'], dataset=args.dset)
    n_classes = utils.get_n_classes(args.dset)

    with tf.device('/cpu:0'):
        dset_train = tf.contrib.data.TFRecordDataset(train_filenames,
                                                     compression_type="GZIP")
        dset_train = dset_train.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=True))
        seed = tf.placeholder(tf.int64, shape=())
        dset_train = dset_train.shuffle(100, seed=seed)
        dset_train = dset_train.batch(args.batch_sz)

        if val_filenames:
            dset_val = tf.contrib.data.TFRecordDataset(val_filenames,
                                                       compression_type="GZIP")
            dset_val = dset_val.map(
                lambda x: parsers._parse_fun_2stream(x, is_training=False))
            dset_val = dset_val.batch(args.batch_sz)

        dset_test = tf.contrib.data.TFRecordDataset(test_filenames,
                                                    compression_type="GZIP")
        dset_test = dset_test.map(
            lambda x: parsers._parse_fun_2stream(x, is_training=False))
        dset_test = dset_test.batch(args.batch_sz)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.contrib.data.Iterator.from_string_handle(
            handle, dset_train.output_types, dset_train.output_shapes)

        train_iterator = dset_train.make_initializable_iterator()
        if val_filenames:
            val_iterator = dset_val.make_initializable_iterator()
        test_iterator = dset_test.make_initializable_iterator()

        next_element = iterator.get_next()
        images_depth_stacked = next_element[0]  # [batch, pooled_frames, h,w,c]
        images_rgb_stacked = next_element[1]
        if args.dset == 'uwa3dii':  # uwa3dii tfrecord labels run 1..30; shift to zero-based
            labels = next_element[2] - 1
        elif 'ntu' in args.dset or args.dset == 'nwucla':
            labels = next_element[2]
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, n_classes)

        rgb_stack_shape = tf.shape(images_rgb_stacked)
        depth_stack_shape = tf.shape(images_depth_stacked)
        # reshape to [batch * pooled_frames, h,w,c]
        images_rgb = tf.reshape(
            images_rgb_stacked,
            [rgb_stack_shape[0] * rgb_stack_shape[1], 224, 224, 3])
        images_depth = tf.reshape(
            images_depth_stacked,
            [depth_stack_shape[0] * depth_stack_shape[1], 224, 224, 3])

    # -----TF.CONFIGPROTO------###########################################
    tf_config = tf.ConfigProto(log_device_placement=True)
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True

    # tf Graph input ##############################################
    with tf.device(args.gpu0):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            is_training = tf.placeholder(tf.bool, [])
            nr_frames = parsers.time_bottleneck

            net_depth_out, endpoints_depth, net_hall_out, endpoints_hall = resnet_v1.resnet_twostream_main(
                images_depth,
                images_rgb,
                nr_frames,
                num_classes=n_classes,
                scope_rgb_stream=scope_rgb,
                depth_training=False,
                is_training=is_training,
                interaction=args.interaction)
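            # Two-stream setup: the depth stream acts as a frozen teacher
            # (depth_training=False presumably keeps its batch norm in inference mode)
            # while the hallucination stream is fed RGB and trained to mimic it.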

            # TRAIN ###############################
            net_depth_train = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_train = tf.reduce_mean(net_depth_train, axis=1)
            net_hall_train = tf.reshape(
                net_hall_out, [-1, parsers.time_bottleneck, n_classes])
            net_hall_train = tf.reduce_mean(net_hall_train, axis=1)

            # TEST ###############################
            net_hall_test = tf.reshape(
                net_hall_out, [-1, parsers.time_bottleneck, n_classes])
            net_hall_test = tf.reduce_mean(net_hall_test, axis=1)
            net_depth_test = tf.reshape(
                net_depth_out, [-1, parsers.time_bottleneck, n_classes])
            net_depth_test = tf.reduce_mean(net_depth_test, axis=1)

            # losses ##########################################################
            loss_hard_hall_class = slim.losses.softmax_cross_entropy(
                net_hall_train, labels)  # hard labels loss

            T = 10  # Temperature T
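            # Hinton-style distillation: both teacher (depth) and student (hallucination)
            # logits are softened with temperature T before the soft-target cross-entropy,
            # and lambda1 below mixes this term with the ordinary hard-label loss.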
            teacher_soft_labels = utils.softmax_distill(net_depth_train, T)
            student_soft_labels = utils.softmax_distill(net_hall_train, T)
            loss_soft_hall_class = utils.cross_entropy(student_soft_labels,
                                                       teacher_soft_labels)

            lambda1 = 0.5  # imitation weight: contribution of the soft-label cross-entropy
            loss_hall_distill = (1 - lambda1) * loss_hard_hall_class + \
                lambda1 * loss_soft_hall_class

            loss_hall_rect_static = utils.loss_hall_rect(
                endpoints_depth[hallucination_layer_true],
                endpoints_hall[hallucination_layer_hall])

            lambda2 = 0.5  # weight of the distillation term in the total loss
            lambda3 = 0.01
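            # Total objective: lambda2 weights the distillation loss and
            # (1 - lambda2) * lambda3 scales the feature rectification between the
            # depth and hallucination block4 activations.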
            loss = lambda2 * loss_hall_distill + \
                (1 - lambda2) * loss_hall_rect_static * lambda3

            optimizer = tf.train.AdamOptimizer(
                learning_rate=args.learning_rate)
            # freezing depth
            to_remove = [
                x for x in list(
                    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
                if 'resnet_v1_50_depth/' in x.name
            ]
            train_vars = [
                x for x in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                if x not in to_remove
            ]
            minimizing = slim.learning.create_train_op(
                loss, optimizer, variables_to_train=train_vars)

            ###################################################################
            acc_depth_train = utils.accuracy(net_depth_train, labels)
            acc_hall_train = utils.accuracy(net_hall_train, labels)

            n_correct_depth = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_depth_test, labels),
                        tf.float32))
            n_correct_hall = tf.reduce_sum(
                tf.cast(utils.correct_pred(net_hall_test, labels), tf.float32))

    summ_loss_combined = tf.summary.scalar('loss', loss)
    summ_loss_hall_l2 = tf.summary.scalar('loss_hall_l2',
                                          loss_hall_rect_static)
    summ_loss_hard_hall = tf.summary.scalar('loss_hard_hall_class',
                                            loss_hard_hall_class)
    summ_loss_soft_hall = tf.summary.scalar('loss_soft_hall_class',
                                            loss_soft_hall_class)
    summ_acc_train_hall = tf.summary.scalar('acc_train_hall', acc_hall_train)
    summ_acc_train_depth = tf.summary.scalar('acc_train_depth',
                                             acc_depth_train)
    summ_train = tf.summary.merge([
        summ_acc_train_hall, summ_acc_train_depth, summ_loss_hall_l2,
        summ_loss_hard_hall, summ_loss_combined, summ_loss_soft_hall
    ])

    accuracy_value_ = tf.placeholder(tf.float32, shape=())
    summ_acc_val_hall = tf.summary.scalar('acc_val_hall', accuracy_value_)
    summ_acc_val_depth = tf.summary.scalar('acc_val_depth', accuracy_value_)
    summ_acc_test_hall = tf.summary.scalar('acc_test_hall', accuracy_value_)
    summ_acc_test_depth = tf.summary.scalar('acc_test_depth', accuracy_value_)
    test_saver = tf.train.Saver(max_to_keep=3)

    with tf.Session(config=tf_config) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        if val_filenames:
            val_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        summary_writer = tf.summary.FileWriter(
            os.path.join(log_path, args.dset, exp_id), sess.graph)

        f_log = open(os.path.join(log_path, args.dset, exp_id, 'log.txt'), 'a')
        utils.double_log(
            f_log, '\n###############################################\n' +
            exp_id + '\n#####################################\n')
        f_log.write(' '.join(sys.argv[:]) + '\n')
        utils.update_details_file(f_log,
                                  lambda1=lambda1,
                                  lambda2=lambda2,
                                  lambda3=lambda3)

        sess.run(tf.global_variables_initializer())
        restorers.restore_weights_s3(sess, args.ckpt)

        def val_test(value_step, mode='val'):
            if mode == 'val':
                if not val_filenames:
                    return -1
                utils.double_log(f_log, "eval validation set \n")
                sess.run(val_iterator.initializer)
                step_handle = val_handle
                step_samples = len(val_filenames)
                step_summ_hall = summ_acc_val_hall
                step_summ_depth = summ_acc_val_depth
            elif mode == 'test':
                utils.double_log(f_log, "eval test set \n")
                sess.run(test_iterator.initializer)
                step_handle = test_handle
                step_samples = len(test_filenames)
                step_summ_hall = summ_acc_test_hall
                step_summ_depth = summ_acc_test_depth
            try:
                accum_correct_depth = accum_correct_hall = 0
                while True:
                    n_correct_hall1, n_correct_depth1 = sess.run(
                        [n_correct_hall, n_correct_depth],
                        feed_dict={
                            handle: step_handle,
                            is_training: False
                        })
                    accum_correct_depth += n_correct_depth1
                    accum_correct_hall += n_correct_hall1
            except tf.errors.OutOfRangeError:
                acc_hall = accum_correct_hall / step_samples
                acc_depth = accum_correct_depth / step_samples
                summ_hall_acc = sess.run(step_summ_hall,
                                         feed_dict={accuracy_value_: acc_hall})
                summary_writer.add_summary(summ_hall_acc, value_step)
                summ_depth_acc = sess.run(
                    step_summ_depth, feed_dict={accuracy_value_: acc_depth})
                summary_writer.add_summary(summ_depth_acc, value_step)
                utils.double_log(f_log, 'Hall acc = %s \n' % str(acc_hall))
                utils.double_log(f_log, 'Depth acc = %s \n' % str(acc_depth))
                return acc_hall

        if args.just_eval:
            val_test(-1, mode='test')
            f_log.close()
            summary_writer.close()
            return

        val_test(-1, mode='val')
        val_test(-1, mode='test')
        n_step = 0
        best_acc = best_epoch = best_step = -1
        for epoch in range(args.n_epochs):
            utils.double_log(f_log, 'epoch %s \n' % str(epoch))
            sess.run(train_iterator.initializer, feed_dict={seed: epoch})
            try:
                while True:
                    print(n_step)
                    if n_step % 100 == 0:
                        _, summary = sess.run([minimizing, summ_train],
                                              feed_dict={
                                                  handle: train_handle,
                                                  is_training: True
                                              })
                        summary_writer.add_summary(summary, n_step)
                    else:
                        sess.run([minimizing],
                                 feed_dict={
                                     handle: train_handle,
                                     is_training: True
                                 })
                    n_step += 1
            except tf.errors.OutOfRangeError:
                acc_validation = val_test(n_step, mode='val')

            if val_filenames:
                acc_epoch = acc_validation
            else:
                continue
            if acc_epoch >= best_acc:
                best_acc = acc_epoch
                best_epoch = epoch
                best_step = n_step
                test_saver.save(sess,
                                os.path.join(ckpt_path, args.dset, exp_id,
                                             'test/model.ckpt'),
                                global_step=n_step)

        utils.double_log(f_log, "Optimization Finished!\n")
        if val_filenames:  # restore best validation model
            utils.double_log(
                f_log,
                str("Best Validation Accuracy: %f at epoch %d %d\n" %
                    (best_acc, best_epoch, best_step)))
            variables_to_restore = slim.get_variables_to_restore()
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(
                sess,
                os.path.join(ckpt_path, args.dset, exp_id,
                             'test/model.ckpt-' + str(best_step)))
        else:
            test_saver.save(sess,
                            os.path.join(ckpt_path, args.dset, exp_id,
                                         'test/model.ckpt'),
                            global_step=n_step)

        val_test(n_step + 1, mode='test')
        f_log.close()
        summary_writer.close()