Code Example #1
File: train_23.py  Project: twistfatezz/ISEnet
def main():
    # Parse the command line input
    # for more about argparse see: https://www.cnblogs.com/zknublx/p/6106343.html
    parser = argparse.ArgumentParser(description='Train SESE-Net')

    parser.add_argument('--name',
                        default='/share/donghao/demo10/' + exp +
                        '/trained_models/' + path_ease + '/sesenet',
                        help='project name')

    parser.add_argument('--v1',
                        type=int,
                        default=23,
                        help='choose from 6 10 14')

    parser.add_argument('--attention-module',
                        type=str,
                        default='sese_block',
                        help='input se_block or sese_block or others')
    # todo
    parser.add_argument('--softmax-flag',
                        type=int,
                        default=0,
                        help='1 for paper 0 for v1')

    parser.add_argument('--switch',
                        type=int,
                        default=0,
                        help='0=standard | 1=pre | 2=identity')

    parser.add_argument('--train-or-not',
                        type=str2bool,
                        default='1',
                        help='train flag')

    parser.add_argument('--validate-or-not',
                        type=str2bool,
                        default='0',
                        help='validate flag')

    parser.add_argument('--test-or-not',
                        type=str2bool,
                        default='1',
                        help='test flag')
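    # str2bool is a converter defined elsewhere in the project; a minimal
    # sketch of what it presumably looks like (an assumption, not the
    # project's actual helper):
    #   def str2bool(v):
    #       if str(v).lower() in ('1', 'true', 'yes', 'y', 't'):
    #           return True
    #       if str(v).lower() in ('0', 'false', 'no', 'n', 'f'):
    #           return False
    #       raise argparse.ArgumentTypeError('boolean value expected')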
    # todo
    parser.add_argument('--data-dir',
                        default='/share/donghao/data2/' + exp + '/6p_dataset',
                        help='dataset info directory')

    parser.add_argument('--epochs',
                        type=int,
                        default=epoch,
                        help='number of training epochs')  # todo fine-tune

    parser.add_argument('--batch-size',
                        type=int,
                        default=batch,
                        help='batch size')  # todo: keep it fixed as a power of 2

    parser.add_argument('--tensorboard-dir',
                        default='/share/donghao/demo10/' + exp + '/logs/' +
                        path_ease + '/sesenet_tb',
                        help='name of the tensorboard data directory')

    parser.add_argument('--pb-model-save-path',
                        default='/share/donghao/demo10/' + exp +
                        '/trained_models/' + path_ease + '/sesenet_pb',
                        help='pb model dir')

    parser.add_argument('--tag-string',
                        default='sesenet_pb',
                        help='tag string for model')

    parser.add_argument('--checkpoint-interval',
                        type=int,
                        default=1,
                        help='checkpoint interval')

    parser.add_argument('--max-to-keep',
                        type=int,
                        default=2,
                        help='num of checkpoint files max to keep')

    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0001,
                        help='weight decay for bn')
    # todo
    parser.add_argument('--lr-decay-method-switch',
                        type=int,
                        default=1,
                        help='0=piecewise|others=exponential')

    parser.add_argument(
        '--lr-values',
        type=str,
        default='0.0001;0.00009;0.00008;0.00005;0.000025;0.00001;0.000005',
        help='learning rate values')  # todo piecewise
    # todo: the momentum optimizer usually needs a larger piecewise lr than adam at each "stage"
    parser.add_argument('--lr-boundaries',
                        type=str,
                        default='55260;110520;221040;442080;884160;1768320',
                        help='boundaries at epochs 10:20:40:80:160:320 (batch size 8)')

    parser.add_argument('--lr-value',
                        type=float,
                        default=0.001,
                        help='learning rate for exp decay')

    parser.add_argument('--decay-steps',
                        type=float,
                        default=num_iteration,
                        help='decay every 1 epoch (decay_steps = iterations per epoch)')

    parser.add_argument('--decay-rate',
                        type=float,
                        default=0.97,
                        help='decay rate: for 100 epoch -> lr=0.0001')
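
    # The lr flags above are presumably consumed inside build_model_and_optimizer
    # roughly like this (a sketch under assumptions, not the project's code):
    #   boundaries = [int(b) for b in args.lr_boundaries.split(';')]
    #   values = [float(v) for v in args.lr_values.split(';')]
    #   if args.lr_decay_method_switch == 0:
    #       # piecewise: len(values) == len(boundaries) + 1 (7 values, 6 boundaries here)
    #       lr = tf.train.piecewise_constant(global_step, boundaries, values)
    #   else:
    #       lr = tf.train.exponential_decay(args.lr_value, global_step,
    #                                       args.decay_steps, args.decay_rate)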

    parser.add_argument('--moving-average-decay',
                        type=float,
                        default=0.9997,
                        help='moving avg decay')  # todo: lower the decay to try to improve generalization
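    # The flag above presumably feeds a tf.train.ExponentialMovingAverage in
    # the model builder, e.g. (assumed usage, not the project's actual code):
    #   ema = tf.train.ExponentialMovingAverage(args.moving_average_decay, global_step)
    #   ema_op = ema.apply(tf.trainable_variables())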

    parser.add_argument('--momentum',
                        type=float,
                        default=0,
                        help='momentum for the optimizer')  # todo 0.9

    parser.add_argument('--adam', type=float, default=1, help='adam optimizer')

    parser.add_argument('--adagrad',
                        type=float,
                        default=0,
                        help='adagrad optimizer')

    parser.add_argument('--rmsprop',
                        type=float,
                        default=0,
                        help='rmsprop optimizer')

    parser.add_argument('--num-samples-for-test-summary',
                        type=int,
                        default=1000,
                        help='range in 0-10044')

    parser.add_argument('--confusion-matrix-normalization',
                        type=str2bool,
                        default='1',
                        help='confusion matrix norm flag')

    parser.add_argument('--class-names',
                        type=list,
                        default=[
                            np.str_('N'),
                            np.str_('S'),
                            np.str_('V'),
                            np.str_('F'),
                            np.str_('Q')
                        ],
                        help='...')

    parser.add_argument('--continue-training',
                        type=str2bool,
                        default=continue_training,
                        help='continue training flag')

    args = parser.parse_args()

    print('[i] Project name:                 ', args.name)
    print('[i] Model categories(18/34):      ', args.v1)
    print('[i] Attention module categories:  ', args.attention_module)
    print('[i] Softmax flag(for sese):       ', args.softmax_flag)
    print('[i] Attention switch:             ', args.switch)
    print('[i] Train or not:                 ', args.train_or_not)
    print('[i] Validate or not:              ', args.validate_or_not)
    print('[i] Test or not:                  ', args.test_or_not)
    print('[i] Data directory:               ', args.data_dir)
    print('[i] epochs:                       ', args.epochs)
    print('[i] Batch size:                   ', args.batch_size)
    print('[i] Tensorboard directory:        ', args.tensorboard_dir)
    print('[i] Pb model save path:           ', args.pb_model_save_path)
    print('[i] Tag string:                   ', args.tag_string)
    print('[i] Checkpoint interval:          ', args.checkpoint_interval)
    print('[i] Checkpoint max2keep:          ', args.max_to_keep)
    print('[i] Weight decay(bn):             ', args.weight_decay)
    print('[i] Learning rate decay switch:   ', args.lr_decay_method_switch)
    print('[i] Learning rate values:         ', args.lr_values)
    print('[i] Learning rate boundaries:     ', args.lr_boundaries)
    print('[i] Learning rate value(exp):     ', args.lr_value)
    print('[i] Decay steps:                  ', args.decay_steps)
    print('[i] Decay rate:                   ', args.decay_rate)
    print('[i] Moving average decay:         ', args.moving_average_decay)
    print('[i] Momentum:                     ', args.momentum)
    print('[i] Adam:                         ', args.adam)
    print('[i] Adagrad:                      ', args.adagrad)
    print('[i] Rmsprop:                      ', args.rmsprop)
    print('[i] Num of samples for test:      ',
          args.num_samples_for_test_summary)
    print('[i] Confusion matrix norm:        ',
          args.confusion_matrix_normalization)
    print('[i] Class names:                  ', args.class_names)
    print('[i] Continue training:            ', args.continue_training)

    # Find an existing checkpoint & continue training...
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(checkpoint_dir=args.name,
                                              latest_filename=None)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1
        # check ckpt path
        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        # find the latest checkpoint file to go on train-process...
        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            # os.path.basename return the final component of a path
            # for e66.ckpt.data-00000-of-00001 we got ckpt_num=66
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'

        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch
    else:
        metagraph_file = None
        checkpoint_file = None
        try:
            print('[i] Creating directory             {}...'.format(args.name))
            os.makedirs(args.name)
        except OSError as e:
            print('[!]', str(e))
            return 1

    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir, args.batch_size)
        print('[i] training samples:             ', td.num_train)
        print('[i] classes:                      ', td.num_classes)
        print('[i] ecg_chip size:                ',
              f'({td.sample_width}, {td.sample_length})')
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    print('[i] Training ...')
    with tf.Session(config=config) as sess:
        if start_epoch != 0:
            print('[i] Building model from metagraph...')
            xs, ys = build_from_metagraph(sess, metagraph_file,
                                          checkpoint_file)
            process_flag, loss, accuracy, y_predict, train_op = build_optimizer_from_metagraph(
                sess)
        else:
            print('[i] Building model for dual channel...')
            xs, ys, process_flag, y_predict, loss, accuracy, train_op = \
                build_model_and_optimizer(args.softmax_flag, td.num_classes, td.sample_width,
                                          td.sample_length, td.sample_channel,
                                          args.moving_average_decay,
                                          args.lr_decay_method_switch, args.lr_values,
                                          args.lr_boundaries, args.lr_value, args.decay_steps,
                                          args.decay_rate, args.adam, args.momentum,
                                          args.adagrad, args.rmsprop, args.v1, attention_module=args.attention_module,
                                          switch=args.switch, weight_decay=args.weight_decay)
        # todo: a typical but wrong way to initialize a reloaded model:
        # init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        # sess.run(init_op)
        # todo: the right way: keep the restored w&b at their last saved values and only initialize the rest
        initialize_uninitialized_variables(sess)
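        # A plausible sketch of initialize_uninitialized_variables (the helper
        # is defined elsewhere in the project; names here are assumptions):
        #   global_vars = tf.global_variables()
        #   is_init = sess.run([tf.is_variable_initialized(v) for v in global_vars])
        #   not_init = [v for v, ok in zip(global_vars, is_init) if not ok]
        #   if not_init:
        #       sess.run(tf.variables_initializer(not_init))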
        # create various helpers
        # var_list=None defaults to the list of all saveable objects;
        # the commented code below would save the trainable vars plus the bn moving 'mean' & 'variance':
        # var_list = tf.trainable_variables()
        # g_list = tf.global_variables()
        # bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
        # bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
        # var_list += bn_moving_vars
        # saver = tf.train.Saver(var_list=var_list, max_to_keep=3)
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)
        saver = tf.train.Saver(max_to_keep=args.max_to_keep)
        # saver1 = tf.train.Saver(max_to_keep=args.max_to_keep * 2)
        saver2 = tf.train.Saver(max_to_keep=args.max_to_keep * 2)

        # build summaries
        # precision & loss summary
        train_precision = PrecisionSummary(sess, summary_writer, 'train', td.num_train_iter, args.continue_training) \
            if args.train_or_not else None
        test_precision = PrecisionSummary(sess, summary_writer, 'test', td.num_test_iter, args.continue_training) \
            if args.test_or_not else None
        train_loss = LossSummary(sess, summary_writer, 'train', td.num_train_iter, args.continue_training) \
            if args.train_or_not else None
        test_loss = LossSummary(sess, summary_writer, 'test', td.num_test_iter, args.continue_training) \
            if args.test_or_not else None
        # excitation summary
        # train_excitation = ExcitationSummary(sess, summary_writer, 'train', args.attention_module, args.continue_training, path_ease) \
        #     if args.train_or_not else None
        # test_excitation = ExcitationSummary(sess, summary_writer, 'test', args.attention_module, args.continue_training, path_ease) \
        #     if args.test_or_not else None
        # net summary
        net_summary = NetSummary(sess)
        net_summary.build_summaries(args.continue_training)
        if start_epoch == 0:
            net_summaries = sess.run(net_summary.summary_op)  # run
            summary_writer.add_summary(net_summaries, 0)  # add
            summary_writer.flush()  # flush
        # set saved_model builder
        if start_epoch != 0:
            builder = tf.saved_model.builder.SavedModelBuilder(
                args.pb_model_save_path + f'_{start_epoch}')
        else:
            builder = tf.saved_model.builder.SavedModelBuilder(
                args.pb_model_save_path)

        print('[i] Training...')
        # max_acc = 0
        pre_num = 0
        # on the first run start_epoch=0; otherwise start_epoch=last_epoch (from the checkpoint file)
        for e in range(start_epoch, args.epochs):
            # Train ->
            # train_cache = []
            # train_flag_lst = [0, 1, 2, 3]
            if args.train_or_not:
                td.train_iter(process='train', num_epoch=args.epochs)
                description = '[i] Train {:>2}/{}'.format(
                    e + 1, args.epochs)  # epoch_No/total_epoch_No -> e+1/200
                for _ in tqdm(iterable=td.train_tqdm_iter,
                              total=td.num_train_iter,
                              desc=description,
                              unit='batches'):
                    x, y = sess.run(
                        td.train_sample)  # array(?,1,512,2) array(?,5)
                    train_dict = {xs: x, ys: y, process_flag: True}
                    _, acc, los = sess.run([train_op, accuracy, loss],
                                           feed_dict=train_dict)
                    # sample
                    # if train_flag_lst:  # if not find all yet
                    #     dense_y = np.argmax(y, axis=1)  # (?,)
                    #     for index, ele in enumerate(dense_y):
                    #         if ele in train_flag_lst:
                    #             train_cache.append((x[index], ele))  # example(1,512,2) label()
                    #             train_flag_lst.remove(ele)
                    #             break
                    # add for precision & ce loss
                    sub_tp = 0
                    sub_fn = 0
                    sub_fp = 0
                    train_precision.add(acc, sub_tp, sub_fn, sub_fp)
                    train_loss.add(values=los)
            # Test ->
            # test_cache = []
            # test_flag_lst = [0, 1, 2, 3]
            if args.test_or_not:
                td.test_iter(process='test', num_epoch=args.epochs)
                description = '[i] Test {:>2}/{}'.format(e + 1, args.epochs)
                for _ in tqdm(iterable=td.test_tqdm_iter,
                              total=td.num_test_iter,
                              desc=description,
                              unit='batches'):
                    x, y = sess.run(
                        td.test_sample)  # array(?,1,512,2) array(?,5)
                    test_dict = {xs: x, ys: y, process_flag: False}
                    acc, los, predict = sess.run([accuracy, loss, y_predict],
                                                 feed_dict=test_dict)
                    # sample for excitation
                    # if test_flag_lst:  # if not find all yet
                    #     dense_y = np.argmax(y, axis=1)  # (?,)
                    #     for index, ele in enumerate(dense_y):
                    #         if ele in test_flag_lst:
                    #             test_cache.append((x[index], ele))  # example(1,512,2) label()
                    #             test_flag_lst.remove(ele)
                    #             break
                    # derive test prediction for sveb
                    sub_tp = 0
                    sub_fn = 0
                    sub_fp = 0
                    for index, label in enumerate(np.argmax(y, axis=1)):
                        if label == 1:  # sveb sample -> count for recall
                            if predict[index] == 1:
                                sub_tp += 1
                            else:
                                sub_fn += 1
                        else:  # not sveb -> count for ppr
                            if predict[index] == 1:
                                sub_fp += 1
                    # add for precision & ce loss
                    test_precision.add(acc, sub_tp, sub_fn, sub_fp)
                    test_loss.add(values=los)
            # check
            if not args.train_or_not and not args.validate_or_not and not args.test_or_not:
                exit('[!] No procedures implemented!')

            # todo check
            # sess.graph.finalize()
            # todo push & flush tb
            # self, train, valid, test, train_lst, valid_lst, test_lst, xs, epoch
            if args.train_or_not:
                # train_excitation.push(train_cache, xs, e)
                train_precision.push(e)
                train_loss.push(e)
            if args.test_or_not:
                # test_excitation.push(test_cache, xs, e)
                test_precision.push(e)
                test_loss.push(e)
            # run-add net w&b
            net_summaries = sess.run(
                net_summary.summary_op
            )  # run again to derive 'next-step' summary
            summary_writer.add_summary(net_summaries, e + 1)  # add again

            # flush all summaries (loss/precision/ecg_chip and the net's w&b) protocol buffers to disk
            summary_writer.flush()

            # save checkpoint
            if (e + 1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e + 1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)
            # for best acc
            sveb_ppr = test_precision.ppr
            sveb_recall = test_precision.recall
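            # ppr/recall are presumably computed from the accumulated
            # sub_tp/sub_fn/sub_fp counts above, i.e. (standard definitions):
            #   recall = tp / (tp + fn)   # fraction of true sveb beats found
            #   ppr    = tp / (tp + fp)   # positive predictive rate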
            if sveb_recall >= 0.6:
                num = 1 * sveb_ppr + 1 * sveb_recall
                if (e + 1) % args.checkpoint_interval == 0 and num >= pre_num:
                    checkpoint3 = '{}/highest/e{}.ckpt'.format(
                        args.name, e + 1)
                    saver2.save(sess, checkpoint3)
                    # refresh
                    pre_num = num if num > pre_num else pre_num

        # close writer
        summary_writer.close()

        # after all epochs complete, save the pb model...
        print('[i] Saving pb model (after all training epochs finish)...')
        builder.add_meta_graph_and_variables(sess, [args.tag_string])
        builder.save()
        print('[i] program finished!')
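
After training, the SavedModel written by builder.save() can be loaded back for inference. A minimal sketch, assuming the same tag string and export directory as the defaults above (load_pb_model is a hypothetical helper, and the path is abbreviated rather than the full default):

import tensorflow as tf

def load_pb_model(export_dir, tag_string):
    # Restore the graph and variables saved via SavedModelBuilder above.
    sess = tf.Session()
    tf.saved_model.loader.load(sess, [tag_string], export_dir)
    return sess

# e.g. sess = load_pb_model('.../sesenet_pb', 'sesenet_pb')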
Code Example #2
def main():
    # Parse the command line input
    # for more about argparse see: https://www.cnblogs.com/zknublx/p/6106343.html
    parser = argparse.ArgumentParser(description='Train SE-Net')

    parser.add_argument('--name',
                        default='/share/donghao/demo6/' + exp +
                        '/trained_models/' + path_ease + '/senet',
                        help='project name')
    # todo do not change!
    parser.add_argument('--v1',
                        type=int,
                        default=18,
                        help='v1=18: res-18 | v1=34: res-34')
    # todo do not change!
    parser.add_argument('--attention-module',
                        type=str,
                        default='se_block',
                        help='input se_block or sese_block or others')
    # todo do not change! no matters
    parser.add_argument(
        '--softmax-flag',
        type=str2bool,
        default='0',
        help='only for sese_block | 0=single channel 1=all channel')
    # todo do not change!
    parser.add_argument('--switch',
                        type=int,
                        default=0,
                        help='0=standard | 1=pre | 2=identity')

    parser.add_argument('--train-or-not',
                        type=str2bool,
                        default='1',
                        help='train flag')

    parser.add_argument('--validate-or-not',
                        type=str2bool,
                        default='0',
                        help='validate flag')

    parser.add_argument('--test-or-not',
                        type=str2bool,
                        default='1',
                        help='test flag')

    parser.add_argument('--data-dir',
                        default='/share/donghao/demo5/' + exp + '/np_dataset',
                        help='dataset info directory')
    # todo
    parser.add_argument('--epochs',
                        type=int,
                        default=epoch,
                        help='number of training epochs')  # todo fine-tune
    # todo
    parser.add_argument('--batch-size',
                        type=int,
                        default=batch,
                        help='batch size')  # todo: keep it fixed as a power of 2
    # todo
    parser.add_argument('--tensorboard-dir',
                        default='/share/donghao/demo6/' + exp + '/logs/' +
                        path_ease + '/senet_tb',
                        help='name of the tensorboard data directory')
    # todo
    parser.add_argument('--pb-model-save-path',
                        default='/share/donghao/demo6/' + exp +
                        '/trained_models/' + path_ease + '/senet_pb',
                        help='pb model dir')
    # todo for t0-exp-senet tag_string='channel01_1d_resnet_pb_model'
    parser.add_argument('--tag-string',
                        default='senet_pb',
                        help='tag string for model')

    parser.add_argument('--checkpoint-interval',
                        type=int,
                        default=1,
                        help='checkpoint interval')

    parser.add_argument('--max-to-keep',
                        type=int,
                        default=2,
                        help='num of checkpoint files max to keep')

    parser.add_argument('--weight-decay',
                        type=float,
                        default=0.0001,
                        help='weight decay for bn')

    parser.add_argument('--lr-decay-method-switch',
                        type=int,
                        default=1,
                        help='0=piecewise|others=exponential')

    parser.add_argument('--lr-values',
                        type=str,
                        default='0.001;0.0005;0.0001;0.00001',
                        help='learning rate values')  # todo piecewise

    parser.add_argument('--lr-boundaries',
                        type=str,
                        default='353481;706962;2120886',
                        help='learning rate change boundaries (in batches)')

    parser.add_argument('--lr-value',
                        type=float,
                        default=0.0001,
                        help='learning rate for exp decay')  # todo exp decay

    parser.add_argument('--decay-steps',
                        type=float,
                        default=num_iteration,
                        help='decay every 1 epoch (decay_steps = iterations per epoch)')

    parser.add_argument('--decay-rate',
                        type=float,
                        default=0.99,
                        help='decay rate: for 100 epoch -> lr=0.0001')

    parser.add_argument('--moving-average-decay',
                        type=float,
                        default=0.9999,
                        help='moving avg decay')

    parser.add_argument('--momentum',
                        type=float,
                        default=0,
                        help='momentum for the optimizer')  # todo 0.9

    parser.add_argument('--adam', type=float, default=1,
                        help='adam optimizer')  # todo: control variables for a comparison against the single-channel experiment

    parser.add_argument('--adagrad',
                        type=float,
                        default=0,
                        help='adagrad optimizer')

    parser.add_argument('--rmsprop',
                        type=float,
                        default=0,
                        help='rmsprop optimizer')

    parser.add_argument('--num-samples-for-test-summary',
                        type=int,
                        default=1000,
                        help='range in 0-10044')

    parser.add_argument('--confusion-matrix-normalization',
                        type=str2bool,
                        default='1',
                        help='confusion matrix norm flag')

    parser.add_argument('--class-names',
                        type=list,
                        default=[
                            np.str_('N'),
                            np.str_('S'),
                            np.str_('V'),
                            np.str_('F'),
                            np.str_('Q')
                        ],
                        help='...')
    # todo
    parser.add_argument('--continue-training',
                        type=str2bool,
                        default=continue_training,
                        help='continue training flag')

    args = parser.parse_args()

    print('[i] Project name:                 ', args.name)
    print('[i] Model categories(18/34):      ', args.v1)
    print('[i] Attention module categories:  ', args.attention_module)
    print('[i] Softmax flag(for sese):       ', args.softmax_flag)
    print('[i] Attention switch:             ', args.switch)
    print('[i] Train or not:                 ', args.train_or_not)
    print('[i] Validate or not:              ', args.validate_or_not)
    print('[i] Test or not:                  ', args.test_or_not)
    print('[i] Data directory:               ', args.data_dir)
    print('[i] epochs:                       ', args.epochs)
    print('[i] Batch size:                   ', args.batch_size)
    print('[i] Tensorboard directory:        ', args.tensorboard_dir)
    print('[i] Pb model save path:           ', args.pb_model_save_path)
    print('[i] Tag string:                   ', args.tag_string)
    print('[i] Checkpoint interval:          ', args.checkpoint_interval)
    print('[i] Checkpoint max2keep:          ', args.max_to_keep)
    print('[i] Weight decay(bn):             ', args.weight_decay)
    print('[i] Learning rate decay switch:   ', args.lr_decay_method_switch)
    print('[i] Learning rate values:         ', args.lr_values)
    print('[i] Learning rate boundaries:     ', args.lr_boundaries)
    print('[i] Learning rate value(exp):     ', args.lr_value)
    print('[i] Decay steps:                  ', args.decay_steps)
    print('[i] Decay rate:                   ', args.decay_rate)
    print('[i] Moving average decay:         ', args.moving_average_decay)
    print('[i] Momentum:                     ', args.momentum)
    print('[i] Adam:                         ', args.adam)
    print('[i] Adagrad:                      ', args.adagrad)
    print('[i] Rmsprop:                      ', args.rmsprop)
    print('[i] Num of samples for test:      ',
          args.num_samples_for_test_summary)
    print('[i] Confusion matrix norm:        ',
          args.confusion_matrix_normalization)
    print('[i] Class names:                  ', args.class_names)
    print('[i] Continue training:            ', args.continue_training)

    # Find an existing checkpoint & continue training...
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(checkpoint_dir=args.name,
                                              latest_filename=None)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1
        # check ckpt path
        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        # find the latest checkpoint file to go on train-process...
        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            # os.path.basename return the final component of a path
            # for e66.ckpt.data-00000-of-00001 we got ckpt_num=66
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'

        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch
    else:
        metagraph_file = None
        checkpoint_file = None
        try:
            print('[i] Creating directory             {}...'.format(args.name))
            os.makedirs(args.name)
        except OSError as e:
            print('[!]', str(e))
            return 1

    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir, args.batch_size)
        print('[i] training samples:             ', td.num_train)
        print('[i] validation samples:           ', td.num_valid)
        print('[i] classes:                      ', td.num_classes)
        print('[i] ecg_chip size:                ',
              f'({td.sample_width}, {td.sample_length})')
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    print('[i] Training ...')
    with tf.Session(config=config) as sess:
        if start_epoch != 0:
            print('[i] Building model from metagraph...')
            xs, ys = build_from_metagraph(sess, metagraph_file,
                                          checkpoint_file)
            loss, accuracy, y_predict, train_op = build_optimizer_from_metagraph(
                sess)
        else:
            print('[i] Building model for dual channel...')
            xs, ys, y_predict, loss, accuracy, train_op = \
                build_model_and_optimizer(args.softmax_flag, td.num_classes, td.sample_width,
                                          td.sample_length, td.sample_channel,
                                          args.moving_average_decay,
                                          args.lr_decay_method_switch, args.lr_values,
                                          args.lr_boundaries, args.lr_value, args.decay_steps,
                                          args.decay_rate, args.adam, args.momentum,
                                          args.adagrad, args.rmsprop, args.v1, attention_module=args.attention_module,
                                          switch=args.switch, weight_decay=args.weight_decay)

        # todo: a typical but wrong way to initialize a reloaded model:
        # init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        # sess.run(init_op)
        # todo: the right way: keep the restored w&b at their last saved values and only initialize the rest
        initialize_uninitialized_variables(sess)
        # create various helpers
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)
        saver = tf.train.Saver(max_to_keep=args.max_to_keep)
        saver1 = tf.train.Saver(max_to_keep=args.max_to_keep)

        # build summaries
        train_precision = PrecisionSummary(sess, summary_writer, 'train', td.num_train_iter, args.continue_training) \
            if args.train_or_not else None
        valid_precision = PrecisionSummary(sess, summary_writer, 'valid', td.num_valid_iter, args.continue_training) \
            if args.validate_or_not else None
        test_precision = PrecisionSummary(sess, summary_writer, 'test', td.num_test_iter, args.continue_training) \
            if args.test_or_not else None
        train_loss = LossSummary(sess, summary_writer, 'train', td.num_train_iter, args.continue_training) \
            if args.train_or_not else None
        valid_loss = LossSummary(sess, summary_writer, 'valid', td.num_valid_iter, args.continue_training) \
            if args.validate_or_not else None
        test_loss = LossSummary(sess, summary_writer, 'test', td.num_test_iter, args.continue_training) \
            if args.test_or_not else None
        # self, session, writer, process_flag, restore=False
        train_excitation = ExcitationSummary(sess, summary_writer, 'train', args.attention_module, args.continue_training, path_ease) \
            if args.train_or_not else None
        valid_excitation = ExcitationSummary(sess, summary_writer, 'valid', args.attention_module, args.continue_training, path_ease) \
            if args.validate_or_not else None
        test_excitation = ExcitationSummary(sess, summary_writer, 'test', args.attention_module, args.continue_training, path_ease) \
            if args.test_or_not else None
        # set saved_model builder
        if start_epoch != 0:
            builder = tf.saved_model.builder.SavedModelBuilder(
                args.pb_model_save_path + f'_{start_epoch}')
        else:
            builder = tf.saved_model.builder.SavedModelBuilder(
                args.pb_model_save_path)

        print('[i] Training...')
        max_acc = 0
        # on the first run start_epoch=0; otherwise start_epoch=last_epoch (from the checkpoint file)
        for e in range(start_epoch, args.epochs):
            # Train ->
            train_cache = []
            train_flag_lst = [0, 1, 2, 3, 4]
            if args.train_or_not:
                td.train_iter(process='train', num_epoch=args.epochs)
                description = '[i] Train {:>2}/{}'.format(
                    e + 1, args.epochs)  # epoch_No/total_epoch_No -> e+1/200
                for _ in tqdm(iterable=td.train_tqdm_iter,
                              total=td.num_train_iter,
                              desc=description,
                              unit='batches'):
                    x, y = sess.run(
                        td.train_sample)  # array(?,1,512,2) array(?,5)
                    train_dict = {xs: x, ys: y}
                    _, acc, los = sess.run([train_op, accuracy, loss],
                                           feed_dict=train_dict)
                    # sample
                    if train_flag_lst:  # if not find all yet
                        dense_y = np.argmax(y, axis=1)  # (?,)
                        for index, ele in enumerate(dense_y):
                            if ele in train_flag_lst:
                                train_cache.append(
                                    (x[index],
                                     ele))  # example(1,512,2) label()
                                train_flag_lst.remove(ele)
                                break
                    # add for precision & ce loss
                    train_precision.add(acc=acc)
                    train_loss.add(values=los)
            # Validate ->
            validate_cache = []
            validate_flag_lst = [0, 1, 2, 3, 4]
            if args.validate_or_not:
                td.valid_iter(process='validate', num_epoch=args.epochs)
                description = '[i] Valid {:>2}/{}'.format(e + 1, args.epochs)
                for _ in tqdm(iterable=td.valid_tqdm_iter,
                              total=td.num_valid_iter,
                              desc=description,
                              unit='batches'):
                    x, y = sess.run(
                        td.valid_sample)  # array(?,1,512,2) array(?,5)
                    validate_dict = {xs: x, ys: y}
                    acc, los = sess.run([accuracy, loss],
                                        feed_dict=validate_dict)
                    # sample
                    if validate_flag_lst:  # if not find all yet
                        dense_y = np.argmax(y, axis=1)  # (?,)
                        for index, ele in enumerate(dense_y):
                            if ele in validate_flag_lst:
                                validate_cache.append(
                                    (x[index],
                                     ele))  # example(1,512,2) label()
                                validate_flag_lst.remove(ele)
                                break
                    # add for precision & ce loss
                    valid_precision.add(acc=acc)
                    valid_loss.add(values=los)
            # Test ->
            test_cache = []
            test_flag_lst = [0, 1, 2, 3, 4]
            if args.test_or_not:
                td.test_iter(process='test', num_epoch=args.epochs)
                description = '[i] Test {:>2}/{}'.format(e + 1, args.epochs)
                for _ in tqdm(iterable=td.test_tqdm_iter,
                              total=td.num_test_iter,
                              desc=description,
                              unit='batches'):
                    x, y = sess.run(
                        td.test_sample)  # array(?,1,512,2) array(?,5)
                    test_dict = {xs: x, ys: y}
                    acc, los = sess.run([accuracy, loss], feed_dict=test_dict)
                    # sample
                    if test_flag_lst:  # if not find all yet
                        dense_y = np.argmax(y, axis=1)  # (?,)
                        for index, ele in enumerate(dense_y):
                            if ele in test_flag_lst:
                                test_cache.append(
                                    (x[index],
                                     ele))  # example(1,512,2) label()
                                test_flag_lst.remove(ele)
                                break
                    # add for precision & ce loss
                    test_precision.add(acc=acc)
                    test_loss.add(values=los)
            # check
            if not args.train_or_not and not args.validate_or_not and not args.test_or_not:
                exit('[!] No procedures implemented!')

            # todo check
            # sess.graph.finalize()
            # todo push & flush tb
            # self, train, valid, test, train_lst, valid_lst, test_lst, xs, epoch
            if args.train_or_not:
                train_excitation.push(train_cache, xs, e)
                train_precision.push(e)
                train_loss.push(e)
            if args.validate_or_not:
                valid_excitation.push(validate_cache, xs, e)
                valid_precision.push(e)
                valid_loss.push(e)
            if args.test_or_not:
                test_excitation.push(test_cache, xs, e)
                test_precision.push(e)
                test_loss.push(e)

            # flush all summaries (loss/precision/ecg_chip and the net's w&b) protocol buffers to disk
            summary_writer.flush()

            # save checkpoint
            if (e + 1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e + 1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)
            avg_acc = test_precision.precision_cache
            # todo: this block keeps the checkpoint with the highest accuracy ^-^
            if (e + 1) % args.checkpoint_interval == 0 and avg_acc >= max_acc:
                checkpoint2 = '{}/highest/e{}.ckpt'.format(args.name, e + 1)
                saver1.save(sess, checkpoint2)
            # refresh max_acc
            max_acc = avg_acc if avg_acc > max_acc else max_acc

        # close writer
        summary_writer.close()

        # after all epochs complete, save the pb model...
        print('[i] Saving pb model (after all training epochs finish)...')
        builder.add_meta_graph_and_variables(sess, [args.tag_string])
        builder.save()
        print('[i] program finished!')
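
The highest/ checkpoints written by saver1 above can be restored later by importing the metagraph. A minimal sketch, assuming an epoch number n whose e{n}.ckpt files actually exist on disk (restore_best is a hypothetical helper):

import tensorflow as tf

def restore_best(sess, name_dir, n):
    # Rebuild the graph from the .meta file, then load the matching weights.
    ckpt = '{}/highest/e{}.ckpt'.format(name_dir, n)
    saver = tf.train.import_meta_graph(ckpt + '.meta')
    saver.restore(sess, ckpt)
    return saver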