Example #1
def test_acc(path, class_nums, growth_rate, depth):
    inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
    labels = tf.placeholder(tf.int64, [None])
    train_phase = tf.placeholder(tf.bool)
    logits = DenseNet(inputs,
                      nums_out=class_nums,
                      growth_rate=growth_rate,
                      train_phase=train_phase,
                      depth=depth)
    pred = softmax(logits)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(pred, axis=1), labels), tf.float32))
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, "./save_para/densenet.ckpt")
    data, labels_ = read_cifar_data(path)
    acc = 0
    for i in range(data.shape[0] // 100):
        acc += sess.run(accuracy,
                        feed_dict={
                            inputs: data[i * 100:i * 100 + 100],
                            labels: labels_[i * 100:i * 100 + 100],
                            train_phase: False
                        })
    return acc / (data.shape[0] // 100)
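
A minimal sketch of how this routine might be invoked; the test-batch path and the hyper-parameters below are placeholders (they mirror the training call in Example #3), not values from the original repository.

if __name__ == "__main__":
    # hypothetical invocation: CIFAR-10 test batch plus the DenseNet
    # hyper-parameters the checkpoint was trained with
    acc = test_acc("./cifar-10-batches-py/test_batch", class_nums=10, growth_rate=12, depth=40)
    print("test accuracy: %f" % acc)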
Example #2
def test_densenet(modelpath, batch_size):
    dataLoader = DataLoader()
    net = DenseNet.build_densenet()
    net.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    net.build((1, 64, 64, 3))
    net.load_weights(modelpath)
    test_images, test_labels = dataLoader.get_batch_test(batch_size)
    net.evaluate(test_images, test_labels, verbose=2)
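
A possible call site, assuming a weight file produced by one of the training routines below; the filename and test-batch size are placeholders.

if __name__ == "__main__":
    test_densenet("./weight/20_epoch_densenet_weight.h5", batch_size=2000)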
Example #3
def train(batch_size, class_nums, growth_rate, weight_decay, depth, cifar10_path, train_epoch, lr):
    inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])
    labels = tf.placeholder(tf.int64, [None])
    train_phase = tf.placeholder(tf.bool)
    learning_rate = tf.placeholder(tf.float32)
    logits = DenseNet(inputs, nums_out=class_nums, growth_rate=growth_rate, train_phase=train_phase, depth=depth)
    pred = softmax(logits)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), labels), tf.float32))
    one_hot_label = to_OneHot(labels, class_nums)
    cross_entropy_loss = tf.reduce_mean(-tf.log(tf.reduce_sum(pred * one_hot_label, axis=1) + 1e-10))
    regular = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
    Opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9, use_nesterov=True).minimize(cross_entropy_loss + weight_decay * regular)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    path = cifar10_path + "data_batch_"
    valid_path = cifar10_path + "data_batch_5"
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    saver = tf.train.Saver()
    # saver.restore(sess, "./save_para//.\\densenet.ckpt")
    # saver.restore(sess, "./save_para/densenet.ckpt")
    for epoch in range(train_epoch):
        if epoch == train_epoch // 2 or epoch == train_epoch * 3 // 4:
            lr /= 10
        for i in range(1, 6):
            if i != 5:
                data, labels_ = read_cifar_data(path + str(i))
                data, labels_ = shuffle(data, labels_)
            else:
                data, labels_ = read_cifar_data(path + str(i))
                data, labels_ = shuffle(data[:5000], labels_[:5000])
            for j in range(data.shape[0] // batch_size - 1):
                batch_data = data[j * batch_size:j * batch_size + batch_size, :, :, :]
                batch_labels = labels_[j * batch_size:j * batch_size + batch_size]
                [_, loss, acc] = sess.run([Opt, cross_entropy_loss, accuracy], feed_dict={inputs: batch_data, labels: batch_labels, train_phase: True, learning_rate: lr})
                loss_list.append(loss)
                train_acc_list.append(acc)
                if j % 100 == 0:
                    print("Epoch: %d, iter: %d, loss: %f, train_acc: %f"%(epoch, j, loss, acc))
                    np.savetxt("loss.txt", loss_list)
                    np.savetxt("train_acc.txt", train_acc_list)
                    np.savetxt("test_acc.txt", test_acc_list)
            if ((epoch + 1) % 5) == 0:
                vali_acc = validation_acc(inputs, labels, train_phase, accuracy, sess, valid_path)
                test_acc_list.append(vali_acc)
                print("Validation Accuracy: %f"%(vali_acc))
                saver.save(sess, "./save_para/densenet.ckpt")



# if __name__ == "__main__":
#     train(batch_size=64, class_nums=10, growth_rate=12, weight_decay=1e-4, depth=40, train_epoch=5)
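
The commented-out call above omits the cifar10_path and lr arguments that train() requires; a complete invocation might look like this (the dataset path, epoch count and learning rate are assumptions):

if __name__ == "__main__":
    train(batch_size=64, class_nums=10, growth_rate=12, weight_decay=1e-4,
          depth=40, cifar10_path="./cifar-10-batches-py/", train_epoch=300, lr=0.1)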
Example #4
def train_densenet(batch_size, epoch):
    dataLoader = DataLoader()
    # build callbacks
    checkpoint = tf.keras.callbacks.ModelCheckpoint('{epoch}_epoch_densenet_weight.h5',
        save_weights_only=True,
        verbose=1,
        save_freq='epoch')
    # build model
    net = DenseNet.build_densenet()
    net.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001, decay=1e-6), loss='sparse_categorical_crossentropy', metrics=['accuracy'])


    # For full parameter details see the official docs:
    # https://tensorflow.google.cn/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator?hl=en
    data_generate = ImageDataGenerator(
        featurewise_center=False,             # set the input mean to 0 over the dataset
        samplewise_center=False,              # set the mean of each sample to 0
        featurewise_std_normalization=False,  # divide inputs by the dataset std, feature-wise
        samplewise_std_normalization=False,   # divide each sample by its own std
        zca_epsilon=1e-6,                     # epsilon for ZCA whitening, default 1e-6
        zca_whitening=False,                  # whether to apply ZCA whitening
        rotation_range=10,                    # degree range for random rotations (int)
        width_shift_range=0.1,                # horizontal shift (float; values > 1 are pixels)
        height_shift_range=0.1,               # vertical shift (float; values > 1 are pixels)
        shear_range=0.,                       # shear intensity (float)
        zoom_range=0.1,                       # random zoom range (float)
        channel_shift_range=0.,               # range for random channel shifts (float)
        fill_mode='nearest',                  # fill mode for points outside the boundaries (also: constant, reflect, wrap)
        cval=0.,                              # fill value, used when fill_mode='constant'
        horizontal_flip=True,                 # random horizontal flips
        vertical_flip=False,                  # random vertical flips
        rescale=None,                         # rescaling factor; no rescaling if None or 0
        preprocessing_function=None,          # function applied to each input
        data_format='channels_last',          # image data format, defaults to channels_last
        validation_split=0.0
    )
    # Adapted from: https://www.jianshu.com/p/1576da1abd71

    train_images,train_labels = dataLoader.get_batch_train(60000)
    net.fit(
        data_generate.flow(train_images, train_labels, 
            batch_size=batch_size, 
            shuffle=True, 
            #save_to_dir='resource/images'
        ), 
        steps_per_epoch=len(train_images) // batch_size,
        epochs=epoch,
        callbacks=[checkpoint],
        shuffle=True)
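
A hypothetical entry point for this trainer; the batch size and epoch count are placeholders.

if __name__ == "__main__":
    train_densenet(batch_size=64, epoch=20)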
Example #5
def train_densenet(batch_size, epoch):
    dataLoader = DataLoader()
    # build callbacks
    checkpoint = tf.keras.callbacks.ModelCheckpoint(f'./weight/{epoch}_epoch_densenet_weight.h5', save_best_only=True, save_weights_only=True, verbose=1, save_freq='epoch')
    # build model
    net = DenseNet.build_densenet()
    net.compile(tf.keras.optimizers.Adam(lr=0.001, decay=1e-6), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # num_iter = dataLoader.num_train//batch_size
    # for e in range(epoch):
    #     for i in range(num_iter):
    #         train_images, train_labels = dataLoader.get_batch_train(batch_size)
    #         net.fit(train_images, train_labels, shuffle=False, batch_size=batch_size, validation_split=0.1, callbacks=[checkpoint])
    #     net.save_weights("./weight/"+str(e+1)+"epoch_iter"+str(i)+"_resnet_weight.h5")

    data_generate = ImageDataGenerator(
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        zca_epsilon=1e-6,
        zca_whitening=False,
        rotation_range=10,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.,
        zoom_range=0.1,
        channel_shift_range=0,
        fill_mode='nearest',
        cval=0.,
        horizontal_flip=True,
        vertical_flip=False,
        rescale=None,
        preprocessing_function=None,
        data_format='channels_last',
        validation_split=0.0)

    train_images, train_labels = dataLoader.get_batch_train(60000)
    net.fit(data_generate.flow(train_images, train_labels, batch_size=batch_size, shuffle=True,),
            steps_per_epoch=len(train_images)//batch_size,
            epochs=epoch,
            callbacks=[checkpoint],
            shuffle=True)
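
Note that this variant passes save_best_only=True, and ModelCheckpoint monitors val_loss by default, so the fit call needs validation data for "best" to be defined. A sketch of how the function's final fit call could thread a held-out split through (reusing the names above; the get_batch_test call and split size are assumptions, not part of the original function):

    val_images, val_labels = dataLoader.get_batch_test(10000)  # assumed held-out split
    net.fit(data_generate.flow(train_images, train_labels, batch_size=batch_size, shuffle=True),
            steps_per_epoch=len(train_images) // batch_size,
            epochs=epoch,
            validation_data=(val_images, val_labels),
            callbacks=[checkpoint])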
Example #6
    args.add_argument('--epoch', type=int, default=1000)
    args.add_argument('--batch_size', type=int, default=64)
    args.add_argument('--num_classes', type=int, default=600)
    args.add_argument('--input_shape', type=int, default=(256, 256, 3))
    args.add_argument('--sbow_shape', type=int, default=(128, ))
    args.add_argument('--train', type=bool, default=False)
    args.add_argument('--updateDB', type=bool, default=False)
    args.add_argument('--eval', type=bool, default=False)
    args.add_argument('--model_path', type=str, default="./checkpoint/finish")
    args.add_argument('--dataset_path', type=str, default="./data/images/")
    args.add_argument('--checkpoint_path', type=str, default="./checkpoint/")
    args.add_argument('--checkpoint_inteval', type=int, default=10)
    args.add_argument('--k', type=int, default=21)

    config = args.parse_args()

model = DenseNet()
op = Adam(lr=0.001,
          beta_1=0.9,
          beta_2=0.999,
          epsilon=1e-10,
          decay=0.008,
          amsgrad=False)
model.compile(loss=losses.logcosh, optimizer=op, metrics=['mae'])

hist = model.fit(X_train,
                 y_train,
                 epochs=51,
                 batch_size=24,
                 validation_data=(X_test, y_test),
                 verbose=2)
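
A follow-up sketch that inspects the training curves recorded in hist; the history key for the MAE metric differs across Keras versions ('mae' vs 'mean_absolute_error'), so it is looked up defensively.

import matplotlib.pyplot as plt

mae_key = 'mae' if 'mae' in hist.history else 'mean_absolute_error'
plt.plot(hist.history['loss'], label='train loss')
plt.plot(hist.history['val_loss'], label='val loss')
plt.plot(hist.history[mae_key], label='train MAE')
plt.legend()
plt.savefig('densenet_regression_curves.png')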
Example #7
def main():

    init_parser()
    opt.manualSeed = 1
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    save_dir = os.path.join(opt.save_root_dir, subject_names[opt.test_index])

    try:
        os.makedirs(save_dir)  # creates opt.save_root_dir as well if it is missing
        print('create save dir')
    except OSError:
        print('dir already exists!')

    logging.basicConfig(format='%(asctime)s %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        filename=os.path.join(save_dir, 'train.log'),
                        level=logging.INFO)
    logging.info('======================================================')

    # load data
    train_data = MSRA_Dataset(root_path=file_path, opt=opt, train=True)
    train_dataloder = DataLoader(train_data,
                                 batch_size=opt.batchSize,
                                 shuffle=True,
                                 num_workers=int(opt.workers))
    test_data = MSRA_Dataset(root_path=file_path, opt=opt, train=False)
    test_dataloder = DataLoader(test_data,
                                batch_size=opt.batchSize,
                                shuffle=False,
                                num_workers=int(opt.workers))
    print('#Train data:', len(train_data), '#Test data:', len(test_data))
    # print(opt)

    # define model,loss and optimizer
    net = DenseNet()
    # if opt.model != '':
    #     net.load_state_dict(torch.load(os.path.join(save_dir, opt.model)))
    '''hardware problem'''
    net.cuda()
    # print(net)

    criterion = nn.MSELoss(size_average=True).cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=opt.learning_rate,
                          momentum=0.9,
                          weight_decay=0.0005)

    if opt.optimizer != '':
        optimizer.load_state_dict(
            torch.load(os.path.join(save_dir, opt.optimizer)))
    # automatically adjust the learning rate: divide by 10 every 25 epochs
    scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

    train_len = len(train_data)
    test_len = len(test_data)

    for epoch in range(opt.nepoch):
        # adjust learning rate
        scheduler.step(epoch)
        # adjest_lr(optimizer, epoch)
        print('======>>>>> Online epoch: #%d/%d, lr=%f, Test: %s <<<<<======' %
              (epoch + 1, opt.nepoch, scheduler.get_lr()[0],
               subject_names[opt.test_index]))

        # train step
        train_mse, train_mse_wld, timer = train(net, train_dataloder,
                                                criterion, optimizer)

        # time cost
        timer = timer / train_len
        print('==> time to learn 1 sample = %f (ms)' % (timer * 1000))

        # print mse
        train_mse = train_mse / train_len
        train_mse_wld = train_mse_wld / train_len
        print('mean-square error of 1 sample: %f, #train_data = %d' %
              (train_mse, train_len))
        print('average estimation error in world coordinate system: %f (mm)' %
              (train_mse_wld))

        # save net
        store = bool(epoch == opt.nepoch - 1)
        if store:
            torch.save(net.state_dict(), '%s/net_%d.pth' % (save_dir, epoch))
            torch.save(optimizer.state_dict(),
                       '%s/optimizer_%d.pth' % (save_dir, epoch))
        logging.info(
            'Epoch#%d: train error=%e, train wld error = %f mm,  lr = %f' %
            (epoch + 1, train_mse, train_mse_wld, scheduler.get_lr()[0]))

        # evaluation step
        # store = True

        test_mse, test_wld_err, timer = evaluate(net, test_dataloder,
                                                 criterion, optimizer, store)

        # time cost
        timer = timer / test_len
        print('==> time to learn 1 sample = %f (ms)' % (timer * 1000))

        # print mse
        test_mse = test_mse / test_len
        print('mean-square error of 1 sample: %f, #test_data = %d' %
              (test_mse, test_len))
        test_wld_err = test_wld_err / test_len
        print('average estimation error in world coordinate system: %f (mm)' %
              (test_wld_err))
        logging.info(
            'Epoch#%d:  test error=%e, test wld error = %f mm, lr = %f' %
            (epoch + 1, test_mse, test_wld_err, scheduler.get_lr()[0]))
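
Recent PyTorch releases expect StepLR.step() to be called once per epoch after the optimizer updates, without an explicit epoch argument, and expose get_last_lr() for reporting; a sketch of that ordering, reusing the names from the loop above:

    for epoch in range(opt.nepoch):
        train_mse, train_mse_wld, timer = train(net, train_dataloder, criterion, optimizer)
        scheduler.step()                         # step the schedule after the epoch's updates
        current_lr = scheduler.get_last_lr()[0]  # preferred over get_lr() for logging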
Example #8
def main(args):

    parser = argparse.ArgumentParser()

    parser.add_argument('run_name', metavar='N', type=str, help='name of run')
    parser.add_argument('network_type',
                        metavar='N',
                        type=str,
                        help='type of network to use')
    parser.add_argument('gpu_id',
                        metavar='G',
                        type=str,
                        help='which gpu to use')
    parser.add_argument('--print_every',
                        metavar='N',
                        type=int,
                        help='number of iterations before printing',
                        default=-1)
    parser.add_argument('--print_network',
                        action='store_true',
                        help='print_network for debugging')
    parser.add_argument('--data_parallel',
                        type=int,
                        nargs='+',
                        default=None,
                        help='parallelize across multiple gpus')

    parser.add_argument('--test', action='store_true', help='test')
    parser.add_argument('--test_print', action='store_true', help='test')
    parser.add_argument('--valid_iters',
                        metavar='I',
                        type=int,
                        default=100,
                        help='number of validation iters to run every epoch')
    parser.add_argument('--csv_file',
                        metavar='CSV',
                        type=str,
                        default=None,
                        help='name of csv file to write to')

    parser.add_argument(
        '--sweep_lambda',
        action='store_true',
        help=
        'perform a sweep over lambda values keeping other settings as in args')
    parser.add_argument(
        '--sweep_c',
        action='store_true',
        help=
        'perform a sweep over c values keeping other settings as in args')
    parser.add_argument('--sweep_start',
                        metavar='S',
                        type=float,
                        default=0.0,
                        help='lambda value to start sweep at')
    parser.add_argument('--sweep_stop',
                        metavar='E',
                        type=float,
                        default=0.1,
                        help='lambda value to stop sweep at')
    parser.add_argument('--sweep_step',
                        metavar='E',
                        type=float,
                        default=0.01,
                        help='step_size_between_sweep_points')
    parser.add_argument('--sweep_exp',
                        action='store_true',
                        help='use multiplicative (exponential) steps between sweep points')
    parser.add_argument('--sweep_resume',
                        action='store_true',
                        help='resume sweep checkpts')
    parser.add_argument(
        '--sweep_con_runs',
        metavar='C',
        type=int,
        default=1,
        help='number of runs to run with same parameters to validate consistency'
    )

    #checkpoints
    parser.add_argument('--checkpoint_every',
                        type=int,
                        default=10,
                        help='checkpoint every n epochs')
    parser.add_argument('--load_checkpoint',
                        action='store_true',
                        help='load checkpoint with same name')
    parser.add_argument('--resume',
                        action='store_true',
                        help='resume from epoch we left off of when loading')
    parser.add_argument('--checkpoint',
                        type=str,
                        default=None,
                        help='checkpoint to load')

    #params
    parser.add_argument('--epochs',
                        metavar='N',
                        type=int,
                        help='number of epochs to run for',
                        default=50)
    parser.add_argument('--batch_size',
                        metavar='bs',
                        type=int,
                        default=1024,
                        help='batch size')
    parser.add_argument('--lr',
                        metavar='lr',
                        type=float,
                        help='learning rate',
                        default=1e-3)
    parser.add_argument('--rmsprop',
                        action='store_true',
                        help='use rmsprop optimizer')
    parser.add_argument('--sgd', action='store_true', help='use sgd optimizer')
    parser.add_argument('--lr_reduce_on_plateau',
                        action='store_true',
                        help='update optimizer on plateau')
    parser.add_argument('--lr_exp',
                        action='store_true',
                        help='exponentially decay the learning rate')
    parser.add_argument('--lr_step',
                        type=int,
                        nargs='+',
                        default=None,
                        help='decrease lr by gamma = 0.1 on these epochs')
    parser.add_argument('--lr_list',
                        type=float,
                        nargs='+',
                        default=None,
                        help='explicit list of learning rates to use')

    parser.add_argument('--l2_reg', type=float, default=0.0)

    DataLoader.add_args(parser)

    #added so default argument come from the network that is loaded
    network_type = args[1]
    if network_type == 'auto_fc':
        AutoFCNetwork.add_args(parser)
        network_class = AutoFCNetwork
    elif network_type == 'auto_conv':
        AutoConvNetwork.add_args(parser)
        network_class = AutoConvNetwork
    elif network_type == 'class_conv':
        ClassifyConvNetwork.add_args(parser)
        network_class = ClassifyConvNetwork
    elif network_type == 'vgg':
        VGG16.add_args(parser)
        network_class = VGG16
    elif network_type == 'res':
        ResidualConvNetwork.add_args(parser)
        network_class = ResidualConvNetwork
    elif network_type == 'res152':
        ResNet152.add_args(parser)
        network_class = ResNet152
    elif network_type == 'dense':
        DenseNet.add_args(parser)
        network_class = DenseNet
    else:
        raise ValueError('unknown network type: ' + str(network_type))

    args = parser.parse_args(args)

    #***************
    # GPU
    #***************

    if args.data_parallel is not None:
        try:
            del os.environ['CUDA_VISIBLE_DEVICES']
        except KeyError:
            pass
        device = torch.device('cuda:%d' % args.data_parallel[0])
    elif args.gpu_id == '-1':
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = torch.device('cpu')
    else:
        print(bcolors.OKBLUE + 'Using GPU ' + str(args.gpu_id) + bcolors.ENDC)
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
        device = torch.device('cuda')

    #***************
    # Run
    #***************

    if args.sweep_lambda or args.sweep_c:
        run_dir = args.run_name
        run_ckpt_dir = 'ckpts/%s' % run_dir
        if not os.path.isdir(run_ckpt_dir):
            os.mkdir(run_ckpt_dir)
        val_l = []
        loss_l = []
        op_loss_l = []
        reg_loss_l = []
        nodes_l = []
        images_l = []

        val = args.sweep_start
        while val <= args.sweep_stop:
            print(bcolors.OKBLUE + 'Val: %0.1E' % val + bcolors.ENDC)
            args.run_name = run_name = "%s/l%0.1E" % (run_dir, val)

            for i in range(args.sweep_con_runs):
                if args.sweep_con_runs > 0:
                    args.run_name = run_name + '_' + str(i)

                if not (args.sweep_resume
                        and os.path.isdir('ckpts/' + args.run_name)):

                    print(bcolors.OKBLUE + 'Run: %s' % args.run_name +
                          bcolors.ENDC)

                    if args.sweep_lambda:
                        args.reg_lambda = float(val)
                    elif args.sweep_c:
                        args.reg_c = float(val)
                    run_wrapper = RunWrapper(args, network_class, device)

                    if not args.test:
                        run_wrapper.train()

                    loss, op_loss, reg_loss, rem_nodes, acc = run_wrapper.test(
                        load=args.test)
                    if isinstance(rem_nodes, list):
                        rem_nodes = sum(rem_nodes)
                    x, y_hat = run_wrapper.test_print(plot=False, load=False)

                    val_l.append(val)
                    loss_l.append(loss)
                    op_loss_l.append(op_loss)
                    reg_loss_l.append(reg_loss)
                    nodes_l.append(rem_nodes)
                    if images_l == []:
                        images_l.append(x[:17])
                    images_l.append(y_hat[:17])

                    #hopefully this cleans up the gpu memory
                    del run_wrapper

            if args.sweep_exp:
                if val == 0.0:
                    val = args.sweep_start
                else:
                    val = val * args.sweep_step
            else:
                val = val + args.sweep_step

        loss_l = np.array(loss_l)
        op_loss_l = np.array(op_loss_l)
        reg_loss_l = np.array(reg_loss_l)
        nodes_l = np.array(nodes_l)
        if args.sweep_lambda:
            lc_l = val_l  #[l * args.reg_C for l in val_l]
        elif args.sweep_c:
            lc_l = [c * args.reg_lambda for c in val_l]

        print(lc_l)
        print(nodes_l)

        misc.plot_sweep(run_ckpt_dir, lc_l, op_loss_l, nodes_l)
        #misc.sweep_to_image(images_l, run_ckpt_dir)

    else:
        #default single run behavior
        run_wrapper = RunWrapper(args, network_class, device)

        if args.test_print:
            run_wrapper.test_print()
        elif args.test:
            run_wrapper.test()
        else:
            run_wrapper.train()
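
main() receives the raw argument vector (note network_type = args[1] before parse_args), so a plausible entry point forwards sys.argv[1:]; the command line in the comment is only an illustration.

import sys

if __name__ == '__main__':
    # e.g. python run.py my_run dense 0 --epochs 50 --batch_size 1024
    main(sys.argv[1:])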
Example #9
def main(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    print('+++++++++++++++++++++++++++++++++++++++++++++++++')
    print('[Input Arguments]')
    for arg in args.__dict__:
        print(arg, '->', args.__dict__[arg])
    print('+++++++++++++++++++++++++++++++++++++++++++++++++')

    images = tf.placeholder('float32', shape=[None, *args.image_shape], name='images')  # placeholder for images
    labels = tf.placeholder('float32', shape=[None, args.class_num], name='labels')  # placeholder for labels
    training = tf.placeholder('bool', name='training')  # placeholder for training boolean (is training)
    global_step = tf.get_variable(name='global_step', shape=[], dtype='int64',
                                  trainable=False)  # variable for global step
    best_accuracy = tf.get_variable(name='best_accuracy', dtype='float32', trainable=False, initializer=0.0)

    steps_per_epoch = round(args.train_set_size / args.batch_size)
    learning_rate = tf.train.piecewise_constant(global_step, [round(steps_per_epoch * 0.5 * args.epochs),
                                                              round(steps_per_epoch * 0.75 * args.epochs)],
                                                [args.learning_rate, 0.1 * args.learning_rate,
                                                 0.01 * args.learning_rate])

    # output logit from NN
    output = DenseNet.model(images, args.blocks, args.layers, args.growth_rate, args.class_num, args.compression_factor, args.dropout_rate, args.init_subsample, training=training)
    # output = DenseNet.BC_model(images, args.blocks, args.layers, args.growth_rate, args.class_num, args.compression_factor, args.dropout_rate, training=training)

    # loss and optimizer
    with tf.variable_scope('losses'):
        # loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=output)
        loss = tf.losses.softmax_cross_entropy(labels, output, label_smoothing=args.label_smoothing)
        loss = tf.reduce_mean(loss, name='loss')
        l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()], name='l2_loss')

    with tf.variable_scope('optimizers'):
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=args.momentum, use_nesterov=True)
        # optimizer = tf.train.AdamOptimizer(learning_rate)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = optimizer.minimize(loss + l2_loss * args.weight_decay, global_step=global_step)
        train_op = tf.group([train_op, update_ops], name='train_op')

    # accuracy
    with tf.variable_scope('accuracy'):
        output = tf.nn.softmax(output, name='output')
        prediction = tf.equal(tf.argmax(output, 1), tf.argmax(labels, 1), name='prediction')
        accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32), name='accuracy')

    # summary
    train_loss_summary = tf.summary.scalar("train_loss", loss)
    val_loss_summary = tf.summary.scalar("val_loss", loss)
    train_accuracy_summary = tf.summary.scalar("train_acc", accuracy)
    val_accuracy_summary = tf.summary.scalar("val_acc", accuracy)

    saver = tf.train.Saver()
    best_saver = tf.train.Saver()

    with tf.Session() as sess:
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter(args.checkpoint_dir + '/log', sess.graph)

        sess.run(tf.global_variables_initializer())
        augmentations = [lambda image, label: iter_utils.pad_and_crop(image, label, args.image_shape, 4),
                         iter_utils.flip]
        train_iterator = iter_utils.batch_iterator(args.train_record_dir, None, args.batch_size, augmentations, True)
        train_images_batch, train_labels_batch = train_iterator.get_next()
        val_iterator = iter_utils.batch_iterator(args.val_record_dir, None, args.batch_size)
        val_images_batch, val_labels_batch = val_iterator.get_next()
        sess.run(train_iterator.initializer)
        if args.val_set_size != 0:
            sess.run(val_iterator.initializer)

        # restoring checkpoint
        try:
            saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))
            print('checkpoint restored. train from checkpoint')
        except:
            print('failed to load checkpoint. train from the beginning')

        # get initial step
        gstep = sess.run(global_step)
        init_epoch = round(gstep / steps_per_epoch)
        init_epoch = int(init_epoch)

        for epoch_ in range(init_epoch + 1, args.epochs + 1):

            # train
            while True:
                try:
                    train_images, train_labels = sess.run([train_images_batch, train_labels_batch])
                    train_labels = np.eye(args.class_num)[train_labels]
                    gstep, _, loss_, accuracy_, train_loss_sum, train_acc_sum = sess.run(
                        [global_step, train_op, loss, accuracy, train_loss_summary, train_accuracy_summary],
                        feed_dict={images: train_images, labels: train_labels, training: True})
                    print('[global step: ' + str(gstep) + ' / epoch ' + str(epoch_) + '] -> train accuracy: ',
                          accuracy_, ' loss: ', loss_)
                    writer.add_summary(train_loss_sum, gstep)
                    writer.add_summary(train_acc_sum, gstep)
                except tf.errors.OutOfRangeError:
                    sess.run(train_iterator.initializer)
                    break

            predictions = []

            # validation
            if args.val_set_size != 0:
                while True:
                    try:
                        val_images, val_labels = sess.run([val_images_batch, val_labels_batch])
                        val_labels = np.eye(args.class_num)[val_labels]
                        loss_, accuracy_, prediction_, val_loss_sum, val_acc_sum = sess.run(
                            [loss, accuracy, prediction, val_loss_summary, val_accuracy_summary],
                            feed_dict={images: val_images, labels: val_labels, training: False})
                        predictions.append(prediction_)
                        print('[epoch ' + str(epoch_) + '] -> val accuracy: ', accuracy_, ' loss: ', loss_)
                        writer.add_summary(val_loss_sum, gstep)
                        writer.add_summary(val_acc_sum, gstep)
                    except tf.errors.OutOfRangeError:
                        sess.run(val_iterator.initializer)
                        break

            saver.save(sess, args.checkpoint_dir + '/' + args.checkpoint_name, global_step=global_step)

            predictions = np.concatenate(predictions)
            print('best: ', best_accuracy.eval(), '\ncurrent: ', np.mean(predictions))
            if best_accuracy.eval() < np.mean(predictions):
                print('save checkpoint')
                # run the assign immediately so the variable itself is updated,
                # rather than rebinding the name to a lazily-evaluated assign op
                sess.run(best_accuracy.assign(np.mean(predictions)))
                best_saver.save(sess, args.checkpoint_dir + '/best/' + args.checkpoint_name)
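
main(args) reads its settings from an argparse-style namespace; a hypothetical driver is sketched below. The argument names match the attributes main() actually accesses, but every default value and path is a placeholder assumption.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_shape', type=int, nargs=3, default=[32, 32, 3])
    parser.add_argument('--class_num', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--train_set_size', type=int, default=50000)
    parser.add_argument('--val_set_size', type=int, default=10000)
    parser.add_argument('--blocks', type=int, default=3)
    parser.add_argument('--layers', type=int, default=12)
    parser.add_argument('--growth_rate', type=int, default=12)
    parser.add_argument('--compression_factor', type=float, default=0.5)
    parser.add_argument('--dropout_rate', type=float, default=0.2)
    parser.add_argument('--init_subsample', type=int, default=1)
    parser.add_argument('--label_smoothing', type=float, default=0.0)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--train_record_dir', type=str, default='./data/train.tfrecord')
    parser.add_argument('--val_record_dir', type=str, default='./data/val.tfrecord')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint')
    parser.add_argument('--checkpoint_name', type=str, default='densenet')
    main(parser.parse_args())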