def _cifar_rows_to_nhwc(rows):
    """Reshape flat image rows to (N, 3, 32, 32) and transpose them to (N, 32, 32, 3)."""
    return rows.reshape(rows.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)


def _build_train_network(x_input, network):
    """Build the training graph (is_training=True) for `network` and return its logits.

    Variables are created here (reuse=False where the builder accepts it); the
    evaluation graph from `_build_eval_network` later shares them via reuse=True.

    Raises:
        InvalidNetworkName: if `network` is not a supported architecture name.
    """
    if network == 'resnet50':
        return resnet50(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.orthogonal_initializer())
    if network == 'resnet34':
        return resnet34(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    if network == 'resnet18':
        return resnet18(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    if network == 'seresnet50':
        return se_resnet50(x_input, is_training=True, reuse=False,
                           kernel_initializer=tf.orthogonal_initializer())
    if network == 'resnet110':
        return resnet110(x_input, is_training=True, reuse=False,
                         kernel_initializer=tf.orthogonal_initializer())
    if network == 'seresnet110':
        return se_resnet110(x_input, is_training=True, reuse=False,
                            kernel_initializer=tf.orthogonal_initializer())
    if network == 'seresnet152':
        return se_resnet152(x_input, is_training=True, reuse=False,
                            kernel_initializer=tf.orthogonal_initializer())
    if network == 'resnet152':
        # `reuse` is not passed here; this relies on resnet152's own default.
        return resnet152(x_input, is_training=True,
                         kernel_initializer=tf.orthogonal_initializer())
    if network == 'seresnet_fixed':
        # NOTE(review): the evaluation graph passes type='se_ir' but this call
        # does not; confirm get_resnet's default `type` builds the same model.
        return get_resnet(x_input, 152, trainable=True,
                          w_init=tf.orthogonal_initializer())
    if network == 'densenet121':
        return densenet121(x_input, reuse=False, is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    if network == 'densenet169':
        return densenet169(x_input, reuse=False, is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    if network == 'densenet201':
        return densenet201(x_input, reuse=False, is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    if network == 'densenet161':
        return densenet161(x_input, reuse=False, is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    if network == 'densenet100bc':
        return densenet100bc(x_input, reuse=False, is_training=True,
                             kernel_initializer=tf.orthogonal_initializer())
    if network == 'densenet190bc':
        return densenet190bc(x_input, reuse=False, is_training=True,
                             kernel_initializer=tf.orthogonal_initializer())
    if network == 'resnext50':
        return resnext50(x_input, reuse=False, is_training=True, cardinality=32,
                         kernel_initializer=tf.orthogonal_initializer())
    if network == 'resnext110':
        return resnext110(x_input, reuse=False, is_training=True, cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    if network == 'resnext152':
        return resnext152(x_input, reuse=False, is_training=True, cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    if network == 'seresnext50':
        return se_resnext50(x_input, reuse=False, is_training=True, cardinality=32,
                            kernel_initializer=tf.orthogonal_initializer())
    if network == 'seresnext110':
        return se_resnext110(x_input, reuse=False, is_training=True, cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    if network == 'seresnext152':
        return se_resnext152(x_input, reuse=False, is_training=True, cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    raise InvalidNetworkName('Network name is invalid!')


def _build_eval_network(x_input, network):
    """Build the inference graph (is_training=False) for `network`, reusing training variables.

    Must be called after `_build_train_network` on the same graph, because every
    branch passes reuse=True.

    Raises:
        InvalidNetworkName: if `network` is not a supported architecture name.
    """
    if network == 'resnet50':
        return resnet50(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'resnet18':
        return resnet18(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'resnet34':
        return resnet34(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'seresnet50':
        return se_resnet50(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'resnet110':
        return resnet110(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'seresnet110':
        return se_resnet110(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'seresnet152':
        return se_resnet152(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'resnet152':
        return resnet152(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'seresnet_fixed':
        return get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=True)
    if network == 'densenet121':
        return densenet121(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'densenet169':
        return densenet169(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'densenet201':
        return densenet201(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'densenet161':
        return densenet161(x_input, is_training=False, reuse=True, kernel_initializer=None)
    if network == 'densenet100bc':
        return densenet100bc(x_input, reuse=True, is_training=False, kernel_initializer=None)
    if network == 'densenet190bc':
        return densenet190bc(x_input, reuse=True, is_training=False, kernel_initializer=None)
    if network == 'resnext50':
        return resnext50(x_input, is_training=False, reuse=True, cardinality=32,
                         kernel_initializer=None)
    if network == 'resnext110':
        return resnext110(x_input, is_training=False, reuse=True, cardinality=32,
                          kernel_initializer=None)
    if network == 'resnext152':
        return resnext152(x_input, is_training=False, reuse=True, cardinality=32,
                          kernel_initializer=None)
    if network == 'seresnext50':
        return se_resnext50(x_input, reuse=True, is_training=False, cardinality=32,
                            kernel_initializer=None)
    if network == 'seresnext110':
        return se_resnext110(x_input, reuse=True, is_training=False, cardinality=32,
                             kernel_initializer=None)
    if network == 'seresnext152':
        return se_resnext152(x_input, reuse=True, is_training=False, cardinality=32,
                             kernel_initializer=None)
    raise InvalidNetworkName('Network name is invalid!')


def train(args):
    """Train a 100-class image classifier and checkpoint it on test-accuracy improvements.

    Reads from `args`: batch_size, epoch, network, opt ('adam', 'momentum' or
    'nesterov'), train_path and test_path (pickles with b'data' rows and
    b'fine_labels'). Every 100 steps the loss and batch accuracy are printed;
    every 1000 steps the whole test set is evaluated and, on a new best accuracy,
    the session is saved under params/distinct/<network>_<step>.ckpt.

    Raises:
        InvalidNetworkName: if args.network is not a supported architecture.
        ValueError: if args.opt is not a supported optimizer name.
    """
    batch_size = args.batch_size
    epoch = args.epoch
    network = args.network
    opt_name = args.opt

    train_set = unpickle(args.train_path)
    test_set = unpickle(args.test_path)
    train_rows = train_set[b'data']
    test_rows = test_set[b'data']
    x_train = _cifar_rows_to_nhwc(train_rows)
    y_train = train_set[b'fine_labels']
    x_test = _cifar_rows_to_nhwc(test_rows)
    y_test = test_set[b'fine_labels']
    x_train = norm_images(x_train)
    x_test = norm_images(x_test)

    print('-------------------------------')
    print('--train/test len: ', len(train_rows), len(test_rows))
    print('--x_train norm: ', compute_mean_var(x_train))
    print('--x_test norm: ', compute_mean_var(x_test))
    print('--batch_size: ', batch_size)
    print('--epoch: ', epoch)
    print('--network: ', network)
    print('--opt: ', opt_name)
    print('-------------------------------')

    # ------------------------------- Train input -------------------------------
    if not os.path.exists('./trans/tran.tfrecords'):
        generate_tfrecord(x_train, y_train, './trans/', 'tran.tfrecords')
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset = tf.data.TFRecordDataset('./trans/tran.tfrecords')
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None])
    y_input_one_hot = tf.one_hot(y_input, 100)
    lr = tf.placeholder(tf.float32, [])

    prob = _build_train_network(x_input, network)
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=prob, labels=y_input_one_hot))
    # L2 weight decay (coefficient 5e-4) on trainable variables whose name contains 'conv'.
    conv_var = [var for var in tf.trainable_variables() if 'conv' in var.name]
    l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in conv_var])
    loss = l2_loss * 5e-4 + loss

    if opt_name == 'adam':
        optimizer = tf.train.AdamOptimizer(lr)
    elif opt_name == 'momentum':
        optimizer = tf.train.MomentumOptimizer(lr, 0.9)
    elif opt_name == 'nesterov':
        optimizer = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
    else:
        raise ValueError('Unsupported optimizer: {!r}'.format(opt_name))

    # Run whatever the graph registered in UPDATE_OPS (e.g. batch-norm moving
    # statistics, if the network adds any) together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    logit_softmax = tf.nn.softmax(prob)
    acc = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(logit_softmax, 1), y_input), tf.float32))

    # ------------------------------- Test input --------------------------------
    # Check for the test record file itself; the train-file check above does not
    # cover a deleted or missing test.tfrecords.
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    # No shuffle: correct predictions are summed over the whole test set, so order is irrelevant.
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()

    prob_test = _build_eval_network(x_input, network)
    logit_softmax_test = tf.nn.softmax(prob_test)
    # Count (sum) of correct predictions per batch; divided by the test-set size after the pass.
    acc_test = tf.reduce_sum(
        tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input), tf.float32))

    saver = tf.train.Saver(max_to_keep=1, var_list=tf.global_variables())
    ckpt_dir = 'params/distinct/'
    # Create the checkpoint directory up front so saving cannot fail on a missing parent.
    os.makedirs(ckpt_dir, exist_ok=True)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    now_lr = 0.001  # Warm-up rate for the first epoch; lr_schedule supplies later ones.

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        counter = 0
        max_test_acc = -1
        for i in range(epoch):
            sess.run(iterator.initializer)
            while True:
                try:
                    batch_train, label_train = sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    # Print the rate this epoch used (always bound, unlike a fetched value).
                    print('end epoch %d/%d , lr: %f' % (i, epoch, now_lr))
                    now_lr = lr_schedule(i, args.epoch)
                    break
                _, loss_val, acc_val = sess.run(
                    [train_op, loss, acc],
                    feed_dict={x_input: batch_train, y_input: label_train, lr: now_lr})
                counter += 1
                if counter % 100 == 0:
                    print('counter: ', counter, 'loss_val', loss_val, 'acc: ', acc_val)
                if counter % 1000 == 0:
                    print('start test ')
                    sess.run(iterator_test.initializer)
                    avg_acc = []
                    while True:
                        try:
                            batch_test, label_test = sess.run(next_element_test)
                        except tf.errors.OutOfRangeError:
                            break
                        avg_acc.append(sess.run(
                            acc_test,
                            feed_dict={x_input: batch_test, y_input: label_test}))
                    now_test_acc = np.sum(avg_acc) / len(y_test)
                    print('end test ', now_test_acc)
                    if now_test_acc > max_test_acc:
                        print('***** Max test changed: ', now_test_acc)
                        max_test_acc = now_test_acc
                        filename = ckpt_dir + network + '_{}.ckpt'.format(counter)
                        saver.save(sess, filename)
def get_model(x_input, network):
    """Build the inference-mode (is_training=False) graph for `network` and return its output.

    Variables are created fresh (reuse=False), so this is meant for a graph in
    which the model has not been built yet, e.g. before restoring a checkpoint.
    The densenet100bc/190bc and seresnext* branches previously passed
    reuse=True, which is inconsistent with every other fresh-build branch here
    and with test(); they now use reuse=False as well.

    Args:
        x_input: input image tensor handed to the architecture builder.
        network: architecture name such as 'resnet50', 'densenet121' or 'seresnext50'.

    Returns:
        The tensor returned by the selected architecture builder.

    Raises:
        InvalidNetworkName: if `network` is not a supported architecture name.
    """
    if network == 'resnet50':
        return resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'resnet18':
        return resnet18(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'resnet34':
        return resnet34(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'seresnet50':
        return se_resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'resnet110':
        return resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'seresnet110':
        return se_resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'seresnet152':
        return se_resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'resnet152':
        return resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'seresnet_fixed':
        # NOTE(review): reuse=True on a fresh graph, as in test(); confirm
        # get_resnet tolerates that before relying on this branch.
        return get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=True)
    if network == 'densenet121':
        return densenet121(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'densenet169':
        return densenet169(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'densenet201':
        return densenet201(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'densenet161':
        return densenet161(x_input, is_training=False, reuse=False, kernel_initializer=None)
    if network == 'densenet100bc':
        return densenet100bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    if network == 'densenet190bc':
        return densenet190bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    if network == 'resnext50':
        return resnext50(x_input, is_training=False, reuse=False, cardinality=32,
                         kernel_initializer=None)
    if network == 'resnext110':
        return resnext110(x_input, is_training=False, reuse=False, cardinality=32,
                          kernel_initializer=None)
    if network == 'resnext152':
        return resnext152(x_input, is_training=False, reuse=False, cardinality=32,
                          kernel_initializer=None)
    if network == 'seresnext50':
        return se_resnext50(x_input, reuse=False, is_training=False, cardinality=32,
                            kernel_initializer=None)
    if network == 'seresnext110':
        return se_resnext110(x_input, reuse=False, is_training=False, cardinality=32,
                             kernel_initializer=None)
    if network == 'seresnext152':
        return se_resnext152(x_input, reuse=False, is_training=False, cardinality=32,
                             kernel_initializer=None)
    raise InvalidNetworkName('Network name is invalid!')
def test(args):
    """Evaluate a saved checkpoint on the test pickle and print its accuracy.

    Reads args.test_path, args.network and args.ckpt. Test images are
    normalised with norm_images, written to ./trans/test.tfrecords if that file
    does not exist yet, and streamed in batches of 128 through an inference-mode
    graph. Only the trainable variables plus globals whose names contain
    'moving_mean' or 'moving_variance' are restored from args.ckpt.

    Raises:
        InvalidNetworkName: if args.network is not a supported architecture.
    """
    test_set = unpickle(args.test_path)
    test_rows = test_set[b'data']
    x_test = test_rows.reshape(test_rows.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
    y_test = test_set[b'fine_labels']
    x_test = norm_images(x_test)

    network = args.network
    ckpt = args.ckpt
    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None])

    # ------------------------------- Test input --------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    # No shuffle: correct predictions are summed over the whole set, so order is irrelevant.
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()

    if network == 'resnet50':
        prob_test = resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet_fixed':
        # NOTE(review): reuse=True on a fresh graph, unlike every other branch;
        # confirm get_resnet tolerates that.
        prob_test = get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=True)
    elif network == 'densenet121':
        prob_test = densenet121(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input, is_training=False, reuse=False, cardinality=32,
                              kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input, is_training=False, reuse=False, cardinality=32,
                               kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input, is_training=False, reuse=False, cardinality=32,
                               kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input, reuse=False, is_training=False, cardinality=32,
                                 kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input, reuse=False, is_training=False, cardinality=32,
                                  kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input, reuse=False, is_training=False, cardinality=32,
                                  kernel_initializer=None)
    else:
        # Previously fell through and failed later with an unbound `prob_test`.
        raise InvalidNetworkName('Network name is invalid!')

    logit_softmax_test = tf.nn.softmax(prob_test)
    # Count (sum) of correct predictions per batch; divided by the test-set size at the end.
    acc_test = tf.reduce_sum(
        tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input), tf.float32))

    # Restore trainable weights plus the moving mean/variance globals, which are
    # not trainable but are needed for inference.
    g_list = tf.global_variables()
    var_list = tf.trainable_variables()
    var_list += [g for g in g_list if 'moving_mean' in g.name]
    var_list += [g for g in g_list if 'moving_variance' in g.name]
    saver = tf.train.Saver(var_list=var_list)

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        saver.restore(sess, ckpt)
        sess.run(iterator_test.initializer)
        avg_acc = []
        while True:
            try:
                batch_test, label_test = sess.run(next_element_test)
            except tf.errors.OutOfRangeError:
                break
            avg_acc.append(sess.run(
                acc_test, feed_dict={x_input: batch_test, y_input: label_test}))
        print('end test ', np.sum(avg_acc) / len(y_test))