Code example #1
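# A training entry-point sketch (Chainer). Assumed imports, not shown in the original
# snippet: argparse, os, chainer, plus this repo's config, adaptive_AU_database,
# GlobalDataSet, GraphDataset, TemporalLSTM and BPTTUpdater. The script trains a
# TemporalLSTM with MomentumSGD and gradient clipping, snapshots the model and
# optimizer periodically, and can resume from earlier snapshots.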
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gradclip',
                        '-c',
                        type=float,
                        default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot',
                        '-snap',
                        type=int,
                        default=100,
                        help='snapshot interval in iterations for saving checkpoints')
    parser.add_argument(
        '--valid',
        '-v',
        default='',
        help='validation directory path containing the validation txt files')
    parser.add_argument('--train',
                        '-t',
                        default="train",
                        help='Train directory path containing the train txt files')
    parser.add_argument('--database',
                        default="BP4D",
                        help='database to train for')
    parser.add_argument('--lr', '-l', type=float, default=0.001)
    parser.add_argument('--hidden_size',
                        type=int,
                        default=1024,
                        help="hidden_size originally used in open_crf")
    parser.add_argument('--eval_mode',
                        action='store_true',
                        help='whether to evaluate the model')
    parser.add_argument("--need_cache_graph",
                        "-ng",
                        action="store_true",
                        help="whether to cache factor graph to LRU cache")
    parser.add_argument("--bi_lstm",
                        '-bilstm',
                        action='store_true',
                        help="Use bi_lstm as basic component of temporal_lstm")
    parser.add_argument("--num_attrib",
                        type=int,
                        default=2048,
                        help="node feature dimension")
    parser.add_argument("--resume",
                        action="store_true",
                        help="whether to load npz pretrained file")
    parser.add_argument(
        "--snap_individual",
        action="store_true",
        help="whether to snapshot each individual epoch/iteration")

    parser.set_defaults(test=False)
    args = parser.parse_args()
    print_interval = 1, 'iteration'
    val_interval = 5, 'iteration'

    adaptive_AU_database(args.database)

    # the StructuralRNN constructor needs the first frame's factor graph_backup
    dataset = GlobalDataSet(num_attrib=args.num_attrib)
    model = TemporalLSTM(box_num=config.BOX_NUM[args.database],
                         in_size=args.num_attrib,
                         out_size=dataset.label_bin_len,
                         use_bi_lstm=args.bi_lstm,
                         initialW=None)

    train_data = GraphDataset(args.train,
                              attrib_size=args.hidden_size,
                              global_dataset=dataset,
                              need_s_rnn=True,
                              need_cache_factor_graph=args.need_cache_graph)

    train_iter = chainer.iterators.SerialIterator(train_data,
                                                  1,
                                                  shuffle=True,
                                                  repeat=True)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu(args.gpu)

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
    updater = BPTTUpdater(train_iter, optimizer, int(args.gpu))
    trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'),
                                       out=args.out)

    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=print_interval)
    trainer.extend(chainer.training.extensions.PrintReport([
        'iteration', 'epoch', 'elapsed_time', 'lr', 'main/loss',
        "main/accuracy"
    ]),
                   trigger=print_interval)
    log_name = "temporal_lstm.log"
    trainer.extend(
        chainer.training.extensions.LogReport(trigger=print_interval,
                                              log_name=log_name))
    # trainer.extend(chainer.training.extensions.ProgressBar(update_interval=1, training_length=(args.epoch, 'epoch')))
    optimizer_snapshot_name = "{0}_temporal_lstm_optimizer.npz".format(
        args.database)
    trainer.extend(chainer.training.extensions.snapshot_object(
        optimizer, filename=optimizer_snapshot_name),
                   trigger=(args.snapshot, 'iteration'))

    if not args.snap_individual:
        model_snapshot_name = "{0}_temporal_lstm_model.npz".format(
            args.database)
        trainer.extend(chainer.training.extensions.snapshot_object(
            model, filename=model_snapshot_name),
                       trigger=(args.snapshot, 'iteration'))
    else:
        model_snapshot_name = "{0}_temporal_lstm_model_".format(
            args.database) + "{.updater.iteration}.npz"
        trainer.extend(chainer.training.extensions.snapshot_object(
            model, filename=model_snapshot_name),
                       trigger=(args.snapshot, 'iteration'))

    trainer.extend(chainer.training.extensions.ExponentialShift('lr', 0.7),
                   trigger=(5, "epoch"))

    # load pretrained file
    if not args.snap_individual:
        if args.resume and os.path.exists(args.out + os.sep +
                                          model_snapshot_name):
            print("loading model_snapshot_name to model")
            chainer.serializers.load_npz(
                args.out + os.sep + model_snapshot_name, model)
    else:
        if args.resume:
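            # resume from the newest individual snapshot: the digits between the last
            # '_' and the '.npz' suffix encode the iteration at which it was saved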
            file_lst = [
                filename[filename.rindex("_") + 1:filename.rindex(".")]
                for filename in os.listdir(args.out)
            ]
            file_no = sorted(map(int, file_lst))[-1]
            model_snapshot_name = "{0}_temporal_lstm_model_{1}.npz".format(
                args.database, file_no)
            chainer.serializers.load_npz(
                args.out + os.sep + model_snapshot_name, model)

    if args.resume and os.path.exists(args.out + os.sep +
                                      optimizer_snapshot_name):
        print("loading optimizer_snapshot_name to optimizer")
        chainer.serializers.load_npz(
            args.out + os.sep + optimizer_snapshot_name, optimizer)

    if chainer.training.extensions.PlotReport.available():
        trainer.extend(chainer.training.extensions.PlotReport(
            ['main/loss'], file_name="train_loss.png"),
                       trigger=(100, "iteration"))
        # trainer.extend(chainer.training.extensions.PlotReport(['opencrf_val/F1','opencrf_val/accuracy'],
        #                                                       file_name="val_f1.png"), trigger=val_interval)

    # if args.valid:
    #     valid_data = S_RNNPlusDataset(args.valid, attrib_size=args.hidden_size, global_dataset=dataset,
    #                                   need_s_rnn=True,need_cache_factor_graph=args.need_cache_graph)  # attrib_size控制open-crf层的weight长度
    #     validate_iter = chainer.iterators.SerialIterator(valid_data, 1, shuffle=False, repeat=False)
    #     crf_evaluator = OpenCRFEvaluator(iterator=validate_iter, target=model, device=args.gpu)
    #     trainer.extend(crf_evaluator, trigger=val_interval, name="opencrf_val")

    trainer.run()
Code example #2
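# An OpenCRFLayer training sketch. Assumed imports beyond the deferred ones inside
# main(): argparse, os, cProfile, pstats, chainer, config, adaptive_AU_database,
# and PrintReport / StandardUpdater from chainer.training. The script selects a
# pure-Python or Cython CRF backend, supports multiprocess data loading, and can
# optionally profile the whole run.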
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--step_size',
                        '-ss',
                        type=int,
                        default=3000,
                        help='step_size for lr exponential')
    parser.add_argument('--gradclip',
                        '-c',
                        type=float,
                        default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--pretrain',
                        '-pr',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--snapshot',
                        '-snap',
                        type=int,
                        default=100,
                        help='snapshot interval in iterations for saving checkpoints')
    parser.add_argument('--test_mode',
                        action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.add_argument('--valid',
                        '-val',
                        default='',
                        help='Validation directory path containing the validation txt files')
    parser.add_argument('--test',
                        '-tt',
                        default='graph_test',
                        help='Test directory path containing the test txt files')
    parser.add_argument('--train',
                        '-tr',
                        default="D:/toy/",
                        help='Train directory path containing the train txt files')
    parser.add_argument('--train_edge',
                        default="all",
                        help="train temporal/all edges for comparison")
    parser.add_argument('--database', default="BP4D", help="BP4D/DISFA")
    parser.add_argument(
        '--use_pure_python',
        action='store_true',
        help='use the pure-Python open_crf implementation to check that the optimized Cython code works correctly')
    parser.add_argument('--lr', '-l', type=float, default=0.1)
    parser.add_argument("--profile",
                        "-p",
                        action="store_true",
                        help="whether to profile to examine speed bottleneck")
    parser.add_argument("--num_attrib",
                        type=int,
                        default=2048,
                        help="node feature dimension")
    parser.add_argument("--need_cache_graph",
                        "-ng",
                        action="store_true",
                        help="whether to cache factor graph to LRU cache")
    parser.add_argument("--eval_mode",
                        '-eval',
                        action="store_true",
                        help="whether to evaluation or not")
    parser.add_argument("--proc_num", "-pn", type=int, default=1)
    parser.add_argument("--resume",
                        action="store_true",
                        help="resume from pretrained model")
    parser.set_defaults(test=False)
    args = parser.parse_args()
    config.OPEN_CRF_CONFIG["use_pure_python"] = args.use_pure_python
    # config.OPEN_CRF_CONFIG must be set before the imports below, because it influences the open_crf layer
    from graph_learning.dataset.crf_pact_structure import CRFPackageStructure
    from graph_learning.dataset.graph_dataset import GraphDataset
    from graph_learning.extensions.opencrf_evaluator import OpenCRFEvaluator
    from graph_learning.dataset.graph_dataset_reader import GlobalDataSet
    from graph_learning.updater.bptt_updater import convert
    from graph_learning.extensions.AU_roi_label_split_evaluator import ActionUnitEvaluator
    if args.use_pure_python:

        from graph_learning.model.open_crf.pure_python.open_crf_layer import OpenCRFLayer
    else:
        from graph_learning.model.open_crf.cython.open_crf_layer import OpenCRFLayer

    print_interval = 1, 'iteration'
    val_interval = (5, 'iteration')
    adaptive_AU_database(args.database)
    root_dir = os.path.dirname(os.path.dirname(args.train))
    dataset = GlobalDataSet(num_attrib=args.num_attrib,
                            train_edge=args.train_edge)
    file_name = list(
        filter(lambda e: e.endswith(".txt"), os.listdir(args.train)))[0]
    sample = dataset.load_data(args.train + os.sep + file_name)
    print("pre load done")

    crf_pact_structure = CRFPackageStructure(
        sample, dataset, num_attrib=dataset.num_attrib_type, need_s_rnn=False)
    model = OpenCRFLayer(node_in_size=dataset.num_attrib_type,
                         weight_len=crf_pact_structure.num_feature)

    train_str = args.train
    if train_str[-1] == "/":
        train_str = train_str[:-1]
    trainer_keyword = os.path.basename(train_str)
    trainer_keyword_tuple = tuple(trainer_keyword.split("_"))
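    # the training folder name encodes an AU label split (e.g. a "1_2_4"-style keyword);
    # runs whose split is not in the configured LABEL_SPLIT table are skipped below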
    LABEL_SPLIT = config.BP4D_LABEL_SPLIT if args.database == "BP4D" else config.DISFA_LABEL_SPLIT
    if trainer_keyword_tuple not in LABEL_SPLIT:
        return
    # assert "_" in trainer_keyword

    train_data = GraphDataset(args.train,
                              attrib_size=dataset.num_attrib_type,
                              global_dataset=dataset,
                              need_s_rnn=False,
                              need_cache_factor_graph=args.need_cache_graph,
                              get_geometry_feature=False)
    if args.proc_num == 1:
        train_iter = chainer.iterators.SerialIterator(train_data,
                                                      1,
                                                      shuffle=True)
    elif args.proc_num > 1:
        train_iter = chainer.iterators.MultiprocessIterator(
            train_data,
            batch_size=1,
            n_processes=args.proc_num,
            repeat=True,
            shuffle=True,
            n_prefetch=10,
            shared_mem=31457280)
    else:
        raise ValueError("--proc_num must be >= 1")
    optimizer = chainer.optimizers.SGD(lr=args.lr)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))
    updater = StandardUpdater(train_iter, optimizer, converter=convert)
    trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'),
                                       out=args.out)

    if args.test_mode:
        chainer.config.train = False

    trainer.extend(
        PrintReport([
            'iteration',
            'epoch',
            'elapsed_time',
            'lr',
            'main/loss',
            "opencrf_val/main/hit",  #"opencrf_validation/main/U_hit",
            "opencrf_val/main/miss",  #"opencrf_validation/main/U_miss",
            "opencrf_val/main/F1",  #"opencrf_validation/main/U_F1"
            'opencrf_val/main/accuracy',
        ]),
        trigger=print_interval)
    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=print_interval)
    trainer.extend(
        chainer.training.extensions.LogReport(
            trigger=print_interval,
            log_name="open_crf_{}.log".format(trainer_keyword)))

    optimizer_snapshot_name = "{0}_{1}_opencrf_optimizer.npz".format(
        trainer_keyword, args.database)
    model_snapshot_name = "{0}_{1}_opencrf_model.npz".format(
        trainer_keyword, args.database)
    trainer.extend(chainer.training.extensions.snapshot_object(
        optimizer, filename=optimizer_snapshot_name),
                   trigger=(args.snapshot, 'iteration'))

    trainer.extend(chainer.training.extensions.snapshot_object(
        model, filename=model_snapshot_name),
                   trigger=(args.snapshot, 'iteration'))

    if args.resume and os.path.exists(args.out + os.sep + model_snapshot_name):
        print("loading model_snapshot_name to model")
        chainer.serializers.load_npz(args.out + os.sep + model_snapshot_name,
                                     model)
    if args.resume and os.path.exists(args.out + os.sep +
                                      optimizer_snapshot_name):
        print("loading optimizer_snapshot_name to optimizer")
        chainer.serializers.load_npz(
            args.out + os.sep + optimizer_snapshot_name, optimizer)

    # trainer.extend(chainer.training.extensions.ProgressBar(update_interval=1))
    # trainer.extend(chainer.training.extensions.snapshot(),
    #                trigger=(args.snapshot, 'epoch'))

    # trainer.extend(chainer.training.extensions.ExponentialShift('lr', 0.9), trigger=(1, 'epoch'))

    if chainer.training.extensions.PlotReport.available():
        trainer.extend(chainer.training.extensions.PlotReport(
            ['main/loss'],
            file_name="{}_train_loss.png".format(trainer_keyword)),
                       trigger=(100, "iteration"))
        trainer.extend(chainer.training.extensions.PlotReport(
            ['opencrf_val/F1', 'opencrf_val/accuracy'],
            file_name="{}_val_f1.png".format(trainer_keyword)),
                       trigger=val_interval)

    if args.valid:
        valid_data = GraphDataset(
            args.valid,
            attrib_size=dataset.num_attrib_type,
            global_dataset=dataset,
            need_s_rnn=False,
            need_cache_factor_graph=args.need_cache_graph)
        validate_iter = chainer.iterators.SerialIterator(valid_data,
                                                         1,
                                                         repeat=False,
                                                         shuffle=False)
        evaluator = OpenCRFEvaluator(iterator=validate_iter,
                                     target=model,
                                     device=-1)
        trainer.extend(evaluator, trigger=val_interval, name="opencrf_val")

    if args.profile:
        cProfile.runctx("trainer.run()", globals(), locals(), "Profile.prof")
        s = pstats.Stats("Profile.prof")
        s.strip_dirs().sort_stats("time").print_stats()
    else:
        trainer.run()
Code example #3
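# A training sketch for the space-time attention / relation networks. Assumed
# imports: argparse, os, chainer, collections.OrderedDict, config,
# adaptive_AU_database, NeighborMode, SpatialEdgeMode, RecurrentType, GlobalDataSet,
# CRFPackageStructure, GraphDataset, StRelationNetPlus, StAttentioNetPlus and
# BPTTUpdater.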
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', '-e', type=int, default=25,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')  # open_crf layer only works for CPU mode
    parser.add_argument('--step_size', '-ss', type=int, default=3000,
                        help='step_size for lr exponential')
    parser.add_argument('--gradclip', '-c', type=float, default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot', '-snap', type=int, default=1, help='snapshot interval in epochs for saving checkpoints')
    parser.add_argument('--test_mode', action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.add_argument("--test", '-tt', default='test',help='Test directory path contains test txt file to evaluation')
    parser.add_argument('--train', '-t', default="train",
                        help='Train directory path containing the train txt files')
    parser.add_argument('--database',  default="BP4D",
                        help='database to train for')
    parser.add_argument('--lr', '-l', type=float, default=0.01)
    parser.add_argument('--neighbor_mode', type=NeighborMode, choices=list(NeighborMode), help='1:concat_all, 2:attention_fuse, 3:random_neighbor, 4:no_neighbor')
    parser.add_argument('--spatial_edge_mode', type=SpatialEdgeMode, choices=list(SpatialEdgeMode), help='1:all_edge, 2:configure_edge, 3:no_edge')
    parser.add_argument('--temporal_edge_mode', type=RecurrentType, choices=list(RecurrentType), help='1:rnn, 2:attention_block, 3:point-wise feed forward (no temporal)')
    parser.add_argument("--use_relation_net", action='store_true', help='whether to use st_relation_net instead of space_time_net')
    parser.add_argument("--relation_net_lstm_first", action='store_true',
                        help='whether to use relation_net_lstm_first_forward in st_relation_net')
    parser.add_argument('--use_geometry_features',action='store_true', help='whether to use geometry features')
    parser.add_argument("--num_attrib", type=int, default=2048, help="number of dimension of each node feature")
    parser.add_argument('--geo_num_attrib', type=int, default=4, help='geometry feature length')
    parser.add_argument('--attn_heads', type=int, default=16, help='attention heads number')
    parser.add_argument('--layers', type=int, default=1, help='edge rnn and node rnn layer')
    parser.add_argument("--use_paper_num_label", action="store_true", help="only to use paper reported number of labels"
                                                                           " to train")
    parser.add_argument("--bi_lstm", action="store_true", help="whether to use bi-lstm as Edge/Node RNN")
    parser.add_argument('--weight_decay',type=float,default=0.0005, help="weight decay")
    parser.add_argument("--proc_num",'-proc', type=int,default=1, help="process number of dataset reader")
    parser.add_argument("--resume",action="store_true", help="whether to load npz pretrained file")
    parser.add_argument('--resume_model', '-rm', help='The relative path to restore model file')
    parser.add_argument("--snap_individual", action="store_true", help='whether to snap shot each fixed step into '
                                                                       'individual model file')
    parser.add_argument("--vis", action='store_true', help='whether to visualize computation graph')



    parser.set_defaults(test=False)
    args = parser.parse_args()
    if args.use_relation_net:
        args.out += "_relationnet"
        print("output file to : {}".format(args.out))
    print_interval = 1, 'iteration'
    val_interval = 5, 'iteration'
    print("""
    ======================================
        argument: 
            neighbor_mode:{0}
            spatial_edge_mode:{1}
            temporal_edge_mode:{2}
            use_geometry_features:{3}
    ======================================
    """.format(args.neighbor_mode, args.spatial_edge_mode, args.temporal_edge_mode, args.use_geometry_features))
    adaptive_AU_database(args.database)
    # the StructuralRNN constructor needs the first frame's factor graph_backup
    dataset = GlobalDataSet(num_attrib=args.num_attrib, num_geo_attrib=args.geo_num_attrib,
                            train_edge="all")
    file_name = list(filter(lambda e: e.endswith(".txt"), os.listdir(args.train)))[0]

    paper_report_label = OrderedDict()
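    # collect the AU indices that the paper reports for this database; if none are
    # collected, paper_report_label_idx stays None and all labels are trained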
    if args.use_paper_num_label:
        for AU_idx,AU in sorted(config.AU_SQUEEZE.items(), key=lambda e:int(e[0])):
            if args.database == "BP4D":
                paper_use_AU = config.paper_use_BP4D
            elif args.database =="DISFA":
                paper_use_AU = config.paper_use_DISFA
            if AU in paper_use_AU:
                paper_report_label[AU_idx] = AU
    paper_report_label_idx = list(paper_report_label.keys())
    if not paper_report_label_idx:
        paper_report_label_idx = None


    sample = dataset.load_data(args.train + os.sep + file_name, npy_in_parent_dir=False,
                               paper_use_label_idx=paper_report_label_idx)  # we load the first sample to construct the S-RNN; it must be passed as a constructor argument
    crf_pact_structure = CRFPackageStructure(sample, dataset, num_attrib=dataset.num_attrib_type)  # only the first frame of one video is read; the node count is fairly stable, so the RNN can be constructed from it
    # because we use a multi-class hinge loss, num_label must be the number of binary-form labels + 1 (the +1 stands for the all-zero label)

    if args.use_relation_net:
        model = StRelationNetPlus(crf_pact_structure, in_size=dataset.num_attrib_type, out_size=dataset.label_bin_len,
                              database=args.database, neighbor_mode=args.neighbor_mode,
                              spatial_edge_mode=args.spatial_edge_mode, recurrent_block_type=args.temporal_edge_mode,
                              attn_heads=args.attn_heads, dropout=0.5, use_geometry_features=args.use_geometry_features,
                              layers=args.layers, bi_lstm=args.bi_lstm, lstm_first_forward=args.relation_net_lstm_first)
    else:
        model = StAttentioNetPlus(crf_pact_structure, in_size=dataset.num_attrib_type, out_size=dataset.label_bin_len,
                              database=args.database, neighbor_mode=args.neighbor_mode,
                              spatial_edge_mode=args.spatial_edge_mode, recurrent_block_type=args.temporal_edge_mode,
                              attn_heads=args.attn_heads, dropout=0.5, use_geometry_features=args.use_geometry_features,
                              layers=args.layers, bi_lstm=args.bi_lstm)

    # note: attrib_size is consumed by the open_crf layer to size its parameter vector
    train_data = GraphDataset(args.train, attrib_size=dataset.num_attrib_type, global_dataset=dataset, need_s_rnn=True,
                              need_cache_factor_graph=False, npy_in_parent_dir=False, get_geometry_feature=True,
                              paper_use_label_idx=paper_report_label_idx)  # args.train is a directory

    train_iter = chainer.iterators.SerialIterator(train_data, 1, shuffle=True, repeat=True)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        if args.use_relation_net:
            model.st_relation_net.to_gpu(args.gpu)
        else:
            model.st_attention_net.to_gpu(args.gpu)

    specific_key = "all_AU_train"
    if paper_report_label_idx:
        specific_key = "paper_AU_num_train"

    optimizer_snapshot_name = "{0}@{1}@st_attention_network_optimizer@{2}@{3}@{4}@{5}.npz".format(
        args.database, specific_key, args.neighbor_mode, args.spatial_edge_mode,
        args.temporal_edge_mode, "use_geo" if args.use_geometry_features else "no_geo")
    model_snapshot_name = "{0}@{1}@st_attention_network_model@{2}@{3}@{4}@{5}.npz".format(
        args.database, specific_key, args.neighbor_mode, args.spatial_edge_mode,
        args.temporal_edge_mode, "use_geo" if args.use_geometry_features else "no_geo")
    if args.snap_individual:
        model_snapshot_name = "{0}@{1}@st_attention_network_model_snapshot_".format(args.database,specific_key)
        model_snapshot_name += "{.updater.iteration}"
        model_snapshot_name += "@{0}@{1}@{2}@{3}.npz".format(args.neighbor_mode,
                                                             args.spatial_edge_mode,
                                                             args.temporal_edge_mode,
                                                             "use_geo" if args.use_geometry_features else "no_geo")
    if os.path.exists(args.out + os.sep + model_snapshot_name):
        print("found trained model file. load trained file: {}".format(args.out + os.sep + model_snapshot_name))
        chainer.serializers.load_npz(args.out + os.sep + model_snapshot_name, model)

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
    optimizer.setup(model)
    # optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
    # optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=args.weight_decay))
    updater = BPTTUpdater(train_iter, optimizer, int(args.gpu))
    trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    interval = (1, 'iteration')
    if args.test_mode:
        chainer.config.train = False
    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=print_interval)
    trainer.extend(chainer.training.extensions.PrintReport(
        ['iteration', 'epoch', 'elapsed_time', 'lr',
         'main/loss', "main/accuracy",
         ]), trigger=print_interval)

    log_name = "st_attention_network_{0}@{1}@{2}@{3}@{4}.log".format(args.database,
                                                                      args.neighbor_mode,
                                                                      args.spatial_edge_mode,
                                                                      args.temporal_edge_mode,
                                                                "use_geo" if args.use_geometry_features else "no_geo")

    trainer.extend(chainer.training.extensions.LogReport(trigger=interval,log_name=log_name))
    # trainer.extend(chainer.training.extensions.ProgressBar(update_interval=1, training_length=(args.epoch, 'epoch')))

    trainer.extend(
        chainer.training.extensions.snapshot_object(optimizer,
                                                    filename=optimizer_snapshot_name),
        trigger=(args.snapshot, 'epoch'))

    trainer.extend(
        chainer.training.extensions.snapshot_object(model,
                                                    filename=model_snapshot_name),
        trigger=(args.snapshot, 'epoch'))

    trainer.extend(chainer.training.extensions.ExponentialShift('lr',0.1), trigger=(10, "epoch"))

    if args.resume and args.resume_model and os.path.exists(args.out + os.sep + args.resume_model):
        print("loading {} to model".format(args.resume_model))
        chainer.serializers.load_npz(args.out + os.sep + args.resume_model, model)
    if args.resume and os.path.exists(args.out + os.sep + optimizer_snapshot_name):
        print("loading {} to optimizer".format(optimizer_snapshot_name))
        chainer.serializers.load_npz(args.out + os.sep + optimizer_snapshot_name, optimizer)

    if chainer.training.extensions.PlotReport.available():
        trainer.extend(chainer.training.extensions.PlotReport(
            ['main/loss'], file_name="train_loss.png"),
                       trigger=val_interval)
        trainer.extend(chainer.training.extensions.PlotReport(
            ['main/accuracy'], file_name="train_accuracy.png"),
                       trigger=val_interval)

    trainer.run()
Code example #4
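# An evaluation sketch for the space-time attention / relation networks. Assumed
# imports: argparse, os, json, chainer, collections.OrderedDict, config,
# adaptive_AU_database, extract_mode, NeighborMode, SpatialEdgeMode, RecurrentType,
# GlobalDataSet, CRFPackageStructure, GraphDataset, StRelationNetPlus,
# StAttentioNetPlus and ActionUnitEvaluator.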
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)'
                        )  # open_crf layer only works for CPU mode
    parser.add_argument(
        "--model", "-m",
        help="pretrained model file path")  # which contains pretrained target
    parser.add_argument("--test",
                        "-tt",
                        default="",
                        help="test txt folder path")
    parser.add_argument("--database",
                        "-db",
                        default="BP4D",
                        help="which database you want to evaluate")
    parser.add_argument(
        "--check",
        "-ck",
        action="store_true",
        help="check that the npy files and all list files were generated correctly (off by default)")
    parser.add_argument("--num_attrib",
                        type=int,
                        default=2048,
                        help="feature dimension")
    parser.add_argument("--geo_num_attrib",
                        type=int,
                        default=4,
                        help='geometry feature dimension')
    parser.add_argument("--train_edge",
                        default="all",
                        help="all/spatio/temporal")
    parser.add_argument("--attn_heads", type=int, default=16)
    parser.add_argument("--layers",
                        type=int,
                        default=1,
                        help="layer number of edge/node rnn")
    parser.add_argument(
        "--bi_lstm",
        action="store_true",
        help="whether or not to use bi_lstm as edge/node rnn base")
    parser.add_argument(
        "--use_relation_net",
        action='store_true',
        help='whether to use st_relation_net instead of space_time_net')
    parser.add_argument(
        "--relation_net_lstm_first",
        action='store_true',
        help='whether to use relation_net_lstm_first_forward in st_relation_net'
    )

    args = parser.parse_args()
    adaptive_AU_database(args.database)
    mode_dict = extract_mode(args.model)
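    # extract_mode presumably parses the training configuration (neighbor / spatial /
    # temporal edge modes, geometry flag, label set) back out of the model file name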

    paper_report_label = OrderedDict()
    if mode_dict["use_paper_report_label_num"]:
        for AU_idx, AU in sorted(config.AU_SQUEEZE.items(),
                                 key=lambda e: int(e[0])):
            if args.database == "BP4D":
                paper_use_AU = config.paper_use_BP4D
            elif args.database == "DISFA":
                paper_use_AU = config.paper_use_DISFA
            if AU in paper_use_AU:
                paper_report_label[AU_idx] = AU
    paper_report_label_idx = list(paper_report_label.keys())
    if not paper_report_label_idx:
        paper_report_label_idx = None

    test_dir = args.test if not args.test.endswith("/") else args.test[:-1]
    assert args.database in test_dir
    dataset = GlobalDataSet(num_attrib=args.num_attrib,
                            num_geo_attrib=args.geo_num_attrib,
                            train_edge=args.train_edge)  # ../data_info.json
    file_name = None
    for _file_name in os.listdir(args.test):
        if os.path.exists(args.test + os.sep +
                          _file_name) and _file_name.endswith(".txt"):
            file_name = args.test + os.sep + _file_name
            break
    sample = dataset.load_data(file_name,
                               npy_in_parent_dir=False,
                               paper_use_label_idx=paper_report_label_idx)
    print("pre load done")

    crf_pact_structure = CRFPackageStructure(
        sample, dataset, num_attrib=dataset.num_attrib_type, need_s_rnn=False)
    print("""
        ======================================
        gpu:{4}
        argument: 
                neighbor_mode:{0}
                spatial_edge_mode:{1}
                temporal_edge_mode:{2}
                use_geometry_features:{3}
                use_paper_report_label_num:{5}
        ======================================
        """.format(mode_dict["neighbor_mode"], mode_dict["spatial_edge_mode"],
                   mode_dict["temporal_edge_mode"],
                   mode_dict["use_geo_feature"], args.gpu,
                   mode_dict["use_paper_report_label_num"]))
    if args.use_relation_net:
        model = StRelationNetPlus(
            crf_pact_structure,
            in_size=dataset.num_attrib_type,
            out_size=dataset.label_bin_len,
            database=args.database,
            neighbor_mode=NeighborMode[mode_dict["neighbor_mode"]],
            spatial_edge_mode=SpatialEdgeMode[mode_dict["spatial_edge_mode"]],
            recurrent_block_type=RecurrentType[
                mode_dict["temporal_edge_mode"]],
            attn_heads=args.attn_heads,
            dropout=0.0,
            use_geometry_features=mode_dict["use_geo_feature"],
            layers=args.layers,
            bi_lstm=args.bi_lstm,
            lstm_first_forward=args.relation_net_lstm_first)
    else:
        model = StAttentioNetPlus(
            crf_pact_structure,
            dataset.num_attrib_type,
            dataset.label_bin_len,
            args.database,
            NeighborMode[mode_dict["neighbor_mode"]],
            SpatialEdgeMode[mode_dict["spatial_edge_mode"]],
            RecurrentType[mode_dict["temporal_edge_mode"]],
            attn_heads=args.attn_heads,
            dropout=0.0,
            use_geometry_features=mode_dict["use_geo_feature"],
            layers=args.layers,
            bi_lstm=args.bi_lstm)
    print("loading {}".format(args.model))
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu(args.gpu)
    with chainer.no_backprop_mode():
        test_data = GraphDataset(directory=test_dir,
                                 attrib_size=dataset.num_attrib_type,
                                 global_dataset=dataset,
                                 need_s_rnn=True,
                                 npy_in_parent_dir=False,
                                 need_cache_factor_graph=False,
                                 get_geometry_feature=True,
                                 paper_use_label_idx=paper_report_label_idx)
        test_iter = chainer.iterators.SerialIterator(test_data,
                                                     1,
                                                     shuffle=False,
                                                     repeat=False)
        au_evaluator = ActionUnitEvaluator(
            test_iter,
            model,
            args.gpu,
            database=args.database,
            paper_report_label=paper_report_label)
        observation = au_evaluator.evaluate()
        with open(
                os.path.dirname(args.model) + os.sep +
                "evaluation_result_{0}@{1}@{2}@{3}@{4}.json".format(
                    args.database, NeighborMode[mode_dict["neighbor_mode"]],
                    SpatialEdgeMode[mode_dict["spatial_edge_mode"]],
                    RecurrentType[mode_dict["temporal_edge_mode"]],
                    mode_dict["use_geo_feature"]), "w") as file_obj:
            file_obj.write(
                json.dumps(observation, indent=4, separators=(',', ': ')))
            file_obj.flush()
Code example #5
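# A GraphAttentionModel (GAT) training sketch. Assumed imports: argparse, os,
# chainer, config, adaptive_AU_database, GlobalDataSet, GraphDataset,
# GraphAttentionModel and BPTTUpdater.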
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)'
                        )  # open_crf layer only works for CPU mode
    parser.add_argument('--step_size',
                        '-ss',
                        type=int,
                        default=3000,
                        help='step_size for lr exponential')
    parser.add_argument('--gradclip',
                        '-c',
                        type=float,
                        default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot',
                        '-snap',
                        type=int,
                        default=20,
                        help='snapshot interval in epochs for saving checkpoints')
    parser.add_argument('--train',
                        '-t',
                        default="train",
                        help='Train directory path containing the train txt files')
    parser.add_argument('--database',
                        default="BP4D",
                        help='database to train for')
    parser.add_argument('--lr', '-l', type=float, default=0.01)
    parser.add_argument('--hidden_size',
                        type=int,
                        default=1024,
                        help="the hidden dimension of the middle layers")
    parser.add_argument("--num_attrib",
                        type=int,
                        default=2048,
                        help="number of dimension of each node feature")
    parser.add_argument("--proc_num",
                        '-proc',
                        type=int,
                        default=1,
                        help="process number of dataset reader")
    parser.add_argument("--need_cache_graph",
                        "-ng",
                        action="store_true",
                        help="whether to cache factor graph to LRU cache")
    parser.add_argument("--resume",
                        action="store_true",
                        help="whether to load npz pretrained file")
    parser.add_argument('--atten_heads',
                        type=int,
                        default=4,
                        help="number of attention heads for parallel learning")
    parser.add_argument('--layer_num',
                        type=int,
                        default=2,
                        help='layer number of GAT')

    args = parser.parse_args()
    print_interval = 1, 'iteration'
    val_interval = 5, 'iteration'

    adaptive_AU_database(args.database)

    box_num = config.BOX_NUM[args.database]
    # the StructuralRNN constructor needs the first frame's factor graph_backup
    dataset = GlobalDataSet(num_attrib=args.num_attrib, train_edge="all")
    file_name = list(
        filter(lambda e: e.endswith(".txt"), os.listdir(args.train)))[0]
    dataset.load_data(
        args.train + os.sep + file_name, False
    )  # load one sample so that dataset.num_attrib_type and dataset.label_bin_len are populated for the model below
    model = GraphAttentionModel(input_dim=dataset.num_attrib_type,
                                hidden_dim=args.hidden_size,
                                class_number=dataset.label_bin_len,
                                atten_heads=args.atten_heads,
                                layers_num=args.layer_num,
                                frame_node_num=box_num)
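    # the GAT model consumes per-frame adjacency matrices, hence
    # need_adjacency_matrix=True in the dataset below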
    # note: attrib_size is hard-coded to 2048 here rather than taken from dataset.num_attrib_type
    train_data = GraphDataset(args.train,
                              attrib_size=2048,
                              global_dataset=dataset,
                              need_s_rnn=False,
                              need_cache_factor_graph=args.need_cache_graph,
                              need_adjacency_matrix=True,
                              npy_in_parent_dir=False,
                              need_factor_graph=False)  # args.train is a directory
    train_iter = chainer.iterators.SerialIterator(train_data,
                                                  1,
                                                  shuffle=True,
                                                  repeat=True)
    if args.gpu >= 0:
        print("using gpu : {}".format(args.gpu))
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu(args.gpu)

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
    updater = BPTTUpdater(train_iter, optimizer, device=args.gpu)
    trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'),
                                       out=args.out)

    interval = (1, 'iteration')
    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=print_interval)
    trainer.extend(chainer.training.extensions.PrintReport([
        'iteration',
        'epoch',
        'elapsed_time',
        'lr',
        'main/loss',
        "main/accuracy",
    ]),
                   trigger=print_interval)

    log_name = "GAT.log"
    trainer.extend(
        chainer.training.extensions.LogReport(trigger=interval,
                                              log_name=log_name))
    # trainer.extend(chainer.training.extensions.ProgressBar(update_interval=1, training_length=(args.epoch, 'epoch')))
    optimizer_snapshot_name = "{0}_GAT_optimizer.npz".format(args.database)
    model_snapshot_name = "{0}_GAT_model.npz".format(args.database)
    trainer.extend(chainer.training.extensions.snapshot_object(
        optimizer, filename=optimizer_snapshot_name),
                   trigger=(args.snapshot, 'epoch'))

    trainer.extend(chainer.training.extensions.snapshot_object(
        model, filename=model_snapshot_name),
                   trigger=(args.snapshot, 'epoch'))
    trainer.extend(chainer.training.extensions.ExponentialShift('lr', 0.7),
                   trigger=(5, "epoch"))

    if args.resume and os.path.exists(args.out + os.sep + model_snapshot_name):
        print("loading model_snapshot_name to model")
        chainer.serializers.load_npz(args.out + os.sep + model_snapshot_name,
                                     model)
    if args.resume and os.path.exists(args.out + os.sep +
                                      optimizer_snapshot_name):
        print("loading optimizer_snapshot_name to optimizer")
        chainer.serializers.load_npz(
            args.out + os.sep + optimizer_snapshot_name, optimizer)

    if chainer.training.extensions.PlotReport.available():
        trainer.extend(chainer.training.extensions.PlotReport(
            ['main/loss'], file_name="train_loss.png"),
                       trigger=(100, "iteration"))
        trainer.extend(chainer.training.extensions.PlotReport(
            ['opencrf_val/F1', 'opencrf_val/accuracy'],
            file_name="val_f1.png"),
                       trigger=val_interval)

    trainer.run()
Code example #6
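# An evaluation sketch that handles both OpenCRFLayer and StructuralRNNPlus
# snapshots. Assumed imports: argparse, os, re, json, chainer, adaptive_AU_database,
# GlobalDataSet, CRFPackageStructure, OpenCRFLayer, StructuralRNNPlus, GraphDataset,
# check_pretrained_model_match_file and ActionUnitRoILabelSplitEvaluator.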
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')  # open_crf layer only works for CPU mode
    parser.add_argument("--target_dir", "-t", default="result", help="pretrained model file path") # which contains pretrained target
    parser.add_argument("--test", "-tt", default="", help="test txt folder path")
    parser.add_argument("--hidden_size", "-hs",default=1024, type=int, help="hidden_size of srnn++")
    parser.add_argument("--database","-db",default="BP4D", help="which database you want to evaluate")
    parser.add_argument("--bi_lstm","-bi", action="store_true", help="srnn++ use bi_lstm or not, if pretrained model use bi_lstm, you must set this flag on")
    parser.add_argument("--check", "-ck", action="store_true", help="default not to check the npy file and all list file generate correctly")
    parser.add_argument("--num_attrib",type=int,default=2048, help="feature dimension")
    parser.add_argument("--train_edge",default="all",help="all/spatio/temporal")
    args = parser.parse_args()
    adaptive_AU_database(args.database)
    test_dir = args.test if not args.test.endswith("/") else args.test[:-1]
    assert args.database in test_dir
    dataset = GlobalDataSet(num_attrib=args.num_attrib, train_edge=args.train_edge) # ../data_info.json
    file_name = None
    for folder in os.listdir(args.test):
        if os.path.isdir(args.test + os.sep + folder):
            for _file_name in os.listdir(args.test + os.sep + folder):
                file_name = args.test + os.sep + folder  + os.sep +_file_name
                break
            break
    sample = dataset.load_data(file_name)
    print("pre load done")


    target_dict = {}
    need_srnn = False
    use_crf = False
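    # build a mapping {AU-split keyword -> loaded model}; the keyword is the digit
    # sequence parsed out of each snapshot file name below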
    for model_path in os.listdir(args.target_dir):  # all pretrained model files for 3_fold_1 sit in one folder, 3_fold_2 in another
        if model_path.endswith("model.npz"):
            assert ("opencrf" in model_path or "srnn_plus" in model_path)
            if "opencrf" in model_path:
                assert not need_srnn  # opencrf and srnn_plus snapshots must not be mixed in one folder
                use_crf = True
                # note that open_crf layer doesn't support GPU
                crf_pact_structure = CRFPackageStructure(sample, dataset, num_attrib=dataset.num_attrib_type, need_s_rnn=False)
                model = OpenCRFLayer(node_in_size=dataset.num_attrib_type, weight_len=crf_pact_structure.num_feature)
                print("loading {}".format(args.target_dir + os.sep + model_path, model))
                chainer.serializers.load_npz(args.target_dir + os.sep + model_path, model)
            elif "srnn_plus" in model_path:
                crf_pact_structure = CRFPackageStructure(sample, dataset, num_attrib=args.hidden_size, need_s_rnn=True)
                with_crf = "crf" in model_path
                need_srnn = True
                model = StructuralRNNPlus(crf_pact_structure, in_size=dataset.num_attrib_type,
                                          out_size=dataset.num_label,
                                          hidden_size=args.hidden_size, with_crf=with_crf,
                                          use_bi_lstm=args.bi_lstm)  # if the pretrained model was trained with bi_lstm, --bi_lstm must be set here too
                print("loading {}".format(args.target_dir + os.sep + model_path))
                chainer.serializers.load_npz(args.target_dir + os.sep + model_path, model)
                if args.gpu >= 0:
                    chainer.cuda.get_device_from_id(args.gpu).use()
                    model.to_gpu(args.gpu)
                    if with_crf:
                        model.open_crf.to_cpu()
            trainer_keyword_pattern = re.compile(r".*?((\d+_)+)_*")
            matcher = trainer_keyword_pattern.match(model_path)
            assert matcher
            trainer_keyword = matcher.group(1)[:-1]
            target_dict[trainer_keyword] = model
    if len(target_dict) == 0:
        print("error , no pretrained npz file in {}".format(args.target_dir))
        return
    if args.check:
        check_pretrained_model_match_file(target_dict, args.test)
    with chainer.no_backprop_mode():
        test_data = GraphDataset(directory=args.test, attrib_size=args.hidden_size, global_dataset=dataset,
                                 need_s_rnn=need_srnn, need_cache_factor_graph=False, target_dict=target_dict)  # if any file uses structural_rnn, every pact_structure needs structural_rnn
        test_iter = chainer.iterators.SerialIterator(test_data, 1, shuffle=False, repeat=False)
        gpu = args.gpu if not use_crf else -1
        print('using gpu :{}'.format(gpu))
        chainer.config.train = False
        au_evaluator = ActionUnitRoILabelSplitEvaluator(test_iter, target_dict, device=gpu, database=args.database)
        observation = au_evaluator.evaluate()
        with open(args.target_dir + os.sep + "evaluation_result.json", "w") as file_obj:
            file_obj.write(json.dumps(observation, indent=4, separators=(',', ': ')))
            file_obj.flush()
Code example #7
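# A StructuralRNNPlus (S-RNN++) training sketch with an optional open-crf output
# layer. Assumed imports: argparse, os, chainer, config, adaptive_AU_database,
# GlobalDataSet, CRFPackageStructure, StructuralRNNPlus, GraphDataset, BPTTUpdater,
# EarlyStoppingTrigger and OpenCRFEvaluator.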
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=0,
                        help='GPU ID (negative value indicates CPU)'
                        )  # open_crf layer only works for CPU mode
    parser.add_argument('--step_size',
                        '-ss',
                        type=int,
                        default=3000,
                        help='step_size for lr exponential')
    parser.add_argument('--gradclip',
                        '-c',
                        type=float,
                        default=5,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--snapshot',
                        '-snap',
                        type=int,
                        default=1,
                        help='snapshot interval in epochs for saving checkpoints')
    parser.add_argument('--test_mode',
                        action='store_true',
                        help='Use tiny datasets for quick tests')
    parser.add_argument(
        '--valid',
        '-v',
        default='',
        help='validation directory path containing the validation txt files')
    parser.add_argument(
        "--test",
        '-tt',
        default='test',
        help='Test directory path containing the test txt files for evaluation')
    parser.add_argument('--train',
                        '-t',
                        default="train",
                        help='Train directory path containing the train txt files')
    parser.add_argument('--database',
                        default="BP4D",
                        help='database to train for')
    parser.add_argument("--stop_eps",
                        '-eps',
                        type=float,
                        default=1e-4,
                        help="f - old_f < eps ==> early stop")
    parser.add_argument('--with_crf',
                        '-crf',
                        action='store_true',
                        help='whether to use open crf layer')
    parser.add_argument('--lr', '-l', type=float, default=0.01)
    parser.add_argument('--crf_lr', type=float, default=0.1)
    parser.add_argument(
        '--hidden_size',
        type=int,
        default=1024,
        help="when using the open-crf layer, hidden_size is the node-feature dimension fed into open-crf")
    parser.add_argument('--eval_mode',
                        action='store_true',
                        help='whether to evaluate the model')
    parser.add_argument("--num_attrib",
                        type=int,
                        default=2048,
                        help="number of dimension of each node feature")
    parser.add_argument("--num_geometry_feature",
                        type=int,
                        default=4,
                        help="number of dimension of each node feature")
    parser.add_argument("--proc_num",
                        '-proc',
                        type=int,
                        default=1,
                        help="process number of dataset reader")
    parser.add_argument("--need_cache_graph",
                        "-ng",
                        action="store_true",
                        help="whether to cache factor graph to LRU cache")
    parser.add_argument("--bi_lstm",
                        '-bilstm',
                        action='store_true',
                        help="Use bi_lstm as basic component of S-RNN")
    parser.add_argument("--resume",
                        action="store_true",
                        help="whether to load npz pretrained file")
    parser.add_argument(
        "--exclude",
        action="store_true",
        help="skip training if a pretrained model file already exists")  # FIXME: temporary addition
    parser.add_argument('--train_edge',
                        default="all",
                        help="train temporal/all edges for comparison")
    parser.set_defaults(test=False)
    args = parser.parse_args()
    assert args.resume != args.exclude  # exactly one of --resume / --exclude must be set; they conflict
    print_interval = 1, 'iteration'
    val_interval = 5, 'iteration'

    adaptive_AU_database(args.database)
    train_str = args.train
    if train_str[-1] == "/":
        train_str = train_str[:-1]
    trainer_keyword = os.path.basename(train_str)
    assert "_" in trainer_keyword

    # the StructuralRNN constructor needs the first frame's factor graph_backup
    dataset = GlobalDataSet(num_attrib=args.num_attrib,
                            num_geo_attrib=args.num_geometry_feature,
                            train_edge=args.train_edge)
    file_name = list(
        filter(lambda e: e.endswith(".txt"), os.listdir(args.train)))[0]
    sample = dataset.load_data(
        args.train + os.sep + file_name
    )  # we load the first sample to construct the S-RNN; it must be passed as a constructor argument
    crf_pact_structure = CRFPackageStructure(
        sample, dataset, num_attrib=args.hidden_size
    )  # only the first frame of one video is read; the node count is fairly stable, so the RNN can be constructed from it
    # because we use a multi-class hinge loss, num_label must be the number of binary-form labels + 1 (the +1 stands for the all-zero label)
    model = StructuralRNNPlus(crf_pact_structure,
                              in_size=dataset.num_attrib_type,
                              out_size=dataset.num_label,
                              hidden_size=args.hidden_size,
                              with_crf=args.with_crf,
                              use_bi_lstm=args.bi_lstm)

    # note: attrib_size is consumed by open_crf to size its parameters, so we must not pass dataset.num_attrib_type here
    train_data = GraphDataset(args.train,
                              attrib_size=args.hidden_size,
                              global_dataset=dataset,
                              need_s_rnn=True,
                              need_cache_factor_graph=args.need_cache_graph,
                              get_geometry_feature=False)  # args.train is a directory

    train_iter = chainer.iterators.SerialIterator(train_data,
                                                  1,
                                                  shuffle=True,
                                                  repeat=True)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.structural_rnn.to_gpu(args.gpu)

    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))
    updater = BPTTUpdater(train_iter, optimizer, int(args.gpu))
    early_stop = EarlyStoppingTrigger(args.epoch,
                                      key='main/loss',
                                      eps=float(args.stop_eps))
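    # note: early_stop is constructed but never attached; passing
    # stop_trigger=early_stop to the Trainer below would enable early stopping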
    if args.with_crf:
        trainer = chainer.training.Trainer(updater,
                                           stop_trigger=(args.epoch, "epoch"),
                                           out=args.out)
        model.open_crf.W.update_rule.hyperparam.lr = float(args.crf_lr)
        model.open_crf.to_cpu()
    else:
        trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'),
                                           out=args.out)

    interval = (1, 'iteration')
    if args.test_mode:
        chainer.config.train = False
    trainer.extend(chainer.training.extensions.observe_lr(),
                   trigger=print_interval)
    trainer.extend(
        chainer.training.extensions.PrintReport([
            'iteration',
            'epoch',
            'elapsed_time',
            'lr',
            'main/loss',
            "main/accuracy",
            "opencrf_val/main/hit",  # "opencrf_validation/main/U_hit",
            "opencrf_val/main/miss",  # "opencrf_validation/main/U_miss",
            "opencrf_val/main/F1",  # "opencrf_validation/main/U_F1"
            'opencrf_val/main/accuracy',
        ]),
        trigger=print_interval)

    log_name = "s_rnn_plus_{}.log".format(trainer_keyword)
    trainer.extend(
        chainer.training.extensions.LogReport(trigger=interval,
                                              log_name=log_name))
    # trainer.extend(chainer.training.extensions.ProgressBar(update_interval=1, training_length=(args.epoch, 'epoch')))
    optimizer_snapshot_name = "{0}_{1}_srnn_plus_optimizer.npz".format(
        trainer_keyword, args.database)
    model_snapshot_name = "{0}_{1}_srnn_plus{2}_model.npz".format(
        trainer_keyword, args.database, "_crf" if args.with_crf else "")
    trainer.extend(chainer.training.extensions.snapshot_object(
        optimizer, filename=optimizer_snapshot_name),
                   trigger=(args.snapshot, 'epoch'))

    trainer.extend(chainer.training.extensions.snapshot_object(
        model, filename=model_snapshot_name),
                   trigger=(args.snapshot, 'epoch'))
    trainer.extend(chainer.training.extensions.ExponentialShift('lr', 0.7),
                   trigger=(5, "epoch"))

    if args.resume and os.path.exists(args.out + os.sep + model_snapshot_name):
        print("loading model_snapshot_name to model")
        chainer.serializers.load_npz(args.out + os.sep + model_snapshot_name,
                                     model)
    elif args.exclude and os.path.exists(args.out + os.sep +
                                         model_snapshot_name):
        print("pretrained file has already exists, exit program")
        return
    if args.resume and os.path.exists(args.out + os.sep +
                                      optimizer_snapshot_name):
        print("loading optimizer_snapshot_name to optimizer")
        chainer.serializers.load_npz(
            args.out + os.sep + optimizer_snapshot_name, optimizer)

    if chainer.training.extensions.PlotReport.available():
        trainer.extend(chainer.training.extensions.PlotReport(
            ['main/loss'],
            file_name="{}_train_loss.png".format(trainer_keyword)),
                       trigger=(100, "iteration"))
        trainer.extend(chainer.training.extensions.PlotReport(
            ['opencrf_val/F1', 'opencrf_val/accuracy'],
            file_name="{}_val_f1.png".format(trainer_keyword)),
                       trigger=val_interval)

    # au_evaluator = ActionUnitEvaluator(iterator=validate_iter, target=model, device=-1, database=args.database,
    #                                    data_info_path=os.path.dirname(args.train) + os.sep + "data_info.json")
    # trainer.extend(au_evaluator, trigger=val_interval, name='au_validation')
    # trainer.extend(Evaluator(validate_iter, model, converter=convert, device=-1), trigger=val_interval,
    #                name='accu_validation')
    # if args.with_crf:
    if args.valid:
        valid_data = GraphDataset(args.valid,
                                  attrib_size=args.hidden_size,
                                  global_dataset=dataset,
                                  need_s_rnn=True,
                                  need_cache_factor_graph=args.need_cache_graph
                                  )  # attrib_size controls the weight length of the open-crf layer
        validate_iter = chainer.iterators.SerialIterator(valid_data,
                                                         1,
                                                         shuffle=False,
                                                         repeat=False)
        crf_evaluator = OpenCRFEvaluator(iterator=validate_iter,
                                         target=model,
                                         device=args.gpu)
        trainer.extend(crf_evaluator, trigger=val_interval, name="opencrf_val")

    trainer.run()