コード例 #1
0
def evaluate(config):
    """
    Reload the configured best HiAGM checkpoint and run it on the test split.

    The checkpoint file is ``config.test.best_checkpoint`` inside
    ``config.train.checkpoint.dir``. Evaluation is delegated to
    ``Trainer.eval`` with epoch ``-1`` and stage ``'TEST'``; its result is
    not returned.

    :param config: helper.configure, Configure Object
    :raises AssertionError: if the checkpoint directory is empty or the named
        checkpoint is not a regular file in it (these asserts are skipped
        when Python runs with ``-O``)
    :return: None
    """
    vocab = Vocab(config, min_freq=5, max_size=50000)
    _, _, test_loader = data_loaders(config, vocab)

    # NOTE(review): the model is built in 'TRAIN' mode although only
    # evaluation follows -- presumably so its layout matches the saved
    # state_dict; confirm.
    model = HiAGM(config,
                  vocab,
                  model_type=config.model.type,
                  model_mode='TRAIN')
    model.to(config.train.device_setting.device)

    # Trainer is constructed with a loss and an optimizer even though no
    # training step is taken in this function.
    penalty_cfg = config.train.loss.recursive_regularization
    hierarchy_file = os.path.join(config.data.data_dir, config.data.hierarchy)
    criterion = ClassificationLoss(hierarchy_file,
                                   vocab.v2i['label'],
                                   recursive_penalty=penalty_cfg.penalty,
                                   recursive_constraint=penalty_cfg.flag)
    optimizer = set_optimizer(config, model)

    checkpoint_dir = config.train.checkpoint.dir
    assert os.listdir(checkpoint_dir), "No model file in checkpoint directory!!"
    checkpoint_file = os.path.join(checkpoint_dir, config.test.best_checkpoint)
    assert os.path.isfile(checkpoint_file), \
        "The predefined checkpoint file does not exist."

    logger.info('Loading Previous Checkpoint...')
    logger.info('Loading from {}'.format(checkpoint_file))
    # load_checkpoint hands back a config as its second value; that config is
    # the one given to the Trainer below.
    _, config = load_checkpoint(model_file=checkpoint_file,
                                model=model,
                                config=config)

    trainer = Trainer(model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      vocab=vocab,
                      config=config)
    model.eval()
    trainer.eval(test_loader, -1, 'TEST')
コード例 #2
0
def train(config):
    """
    Train HTCInfoMax, keep the best Micro-/Macro-F1 checkpoints, then test them.

    If ``config.train.checkpoint.dir`` already exists, training resumes from
    the most recently *modified* regular file in it whose name does not start
    with ``best`` (see ``_latest_resume_checkpoint``); otherwise the directory
    is created. Each epoch trains on the train split, evaluates on the train
    and dev splits, and saves ``best_micro_<model>`` / ``best_macro_<model>``
    whenever the corresponding dev F1 strictly improves. A periodic snapshot
    ``<model>_epoch_<n>`` is written when ``epoch % 10 == 1``.

    An epoch in which neither dev F1 reaches its best value so far increments
    a patience counter. Ties therefore neither increment nor reset it; only a
    strict improvement resets it. The learning rate is updated every
    ``lr_patience`` such epochs and training stops once the counter equals
    ``early_stopping``. Finally every best checkpoint present on disk is
    reloaded and evaluated on the test split.

    :param config: helper.configure, Configure Object
    :return: None
    """
    # loading corpus and generate vocabulary
    corpus_vocab = Vocab(config, min_freq=5, max_size=50000)

    # get data
    train_loader, dev_loader, test_loader = data_loaders(config, corpus_vocab)

    # build up model
    htcinfomax = HTCInfoMax(config, corpus_vocab, model_mode='TRAIN')
    htcinfomax.to(config.train.device_setting.device)
    # define training objective & optimizer
    criterion = ClassificationLoss(
        os.path.join(config.data.data_dir, config.data.hierarchy),
        corpus_vocab.v2i['label'],
        recursive_penalty=config.train.loss.recursive_regularization.penalty,
        recursive_constraint=config.train.loss.recursive_regularization.flag)
    optimize = set_optimizer(config, htcinfomax)

    # get epoch trainer
    trainer = Trainer(model=htcinfomax,
                      criterion=criterion,
                      optimizer=optimize,
                      vocab=corpus_vocab,
                      config=config)

    # set origin log
    best_epoch = [-1, -1]
    best_performance = [0.0, 0.0]
    model_checkpoint = config.train.checkpoint.dir
    # Captured before any resume, so file names use the original model type.
    model_name = config.model.type
    wait = 0

    def snapshot(epoch):
        # State written by every save_checkpoint call below. ``config`` and
        # ``best_performance`` are read at call time, so values rebound by
        # the resume step are the ones recorded.
        return {
            'epoch': epoch,
            'model_type': config.model.type,
            'state_dict': htcinfomax.state_dict(),
            'best_performance': best_performance,
            'optimizer': optimize.state_dict()
        }

    if not os.path.isdir(model_checkpoint):
        # makedirs (rather than mkdir) also creates missing parent directories.
        os.makedirs(model_checkpoint)
    else:
        # loading previous checkpoint
        latest_model_file = _latest_resume_checkpoint(model_checkpoint)
        if latest_model_file is not None:
            logger.info('Loading Previous Checkpoint...')
            logger.info('Loading from {}'.format(latest_model_file))
            best_performance, config = load_checkpoint(
                model_file=latest_model_file,
                model=htcinfomax,
                config=config,
                optimizer=optimize)
            logger.info(
                'Previous Best Performance---- Micro-F1: {}%, Macro-F1: {}%'.
                format(best_performance[0], best_performance[1]))

    # (index into best_*, metric key, log label, checkpoint file prefix)
    tracked_metrics = ((0, 'micro_f1', 'Micro-F1', 'best_micro_'),
                       (1, 'macro_f1', 'Macro-F1', 'best_macro_'))

    # train
    for epoch in range(config.train.start_epoch, config.train.end_epoch):
        start_time = time.time()
        trainer.train(train_loader, epoch)
        trainer.eval(train_loader, epoch, 'TRAIN')
        performance = trainer.eval(dev_loader, epoch, 'DEV')

        # patience bookkeeping: only epochs where neither metric reaches its
        # best value count against patience
        if not (performance['micro_f1'] >= best_performance[0]
                or performance['macro_f1'] >= best_performance[1]):
            wait += 1
            if wait % config.train.optimizer.lr_patience == 0:
                logger.warning(
                    "Performance has not been improved for {} epochs, updating learning rate"
                    .format(wait))
                trainer.update_lr()
            if wait == config.train.optimizer.early_stopping:
                logger.warning(
                    "Performance has not been improved for {} epochs, stopping train with early stopping"
                    .format(wait))
                break

        # saving best model for each metric on strict improvement
        for slot, key, label, prefix in tracked_metrics:
            if performance[key] > best_performance[slot]:
                wait = 0
                logger.info('Improve {} {}% --> {}%'.format(
                    label, best_performance[slot], performance[key]))
                best_performance[slot] = performance[key]
                best_epoch[slot] = epoch
                save_checkpoint(snapshot(epoch),
                                os.path.join(model_checkpoint, prefix + model_name))

        # periodic checkpoint, also the candidate for resuming later
        if epoch % 10 == 1:
            save_checkpoint(
                snapshot(epoch),
                os.path.join(model_checkpoint,
                             model_name + '_epoch_' + str(epoch)))

        logger.info('Epoch {} Time Cost {} secs.'.format(
            epoch,
            time.time() - start_time))

    # Reload each best checkpoint that exists on disk and report its test score.
    for slot, _, _, prefix in tracked_metrics:
        best_epoch_model_file = os.path.join(model_checkpoint, prefix + model_name)
        if os.path.isfile(best_epoch_model_file):
            load_checkpoint(best_epoch_model_file,
                            model=htcinfomax,
                            config=config,
                            optimizer=optimize)
            trainer.eval(test_loader, best_epoch[slot], 'TEST')

    return


def _latest_resume_checkpoint(checkpoint_dir):
    """
    Return the path of the newest resumable checkpoint in a directory, or None.

    Only regular files are considered, and names starting with ``best`` (the
    best-metric snapshots) are excluded. Candidates are ranked by modification
    time, because access time also changes when a file is merely read and is
    not updated at all on ``noatime`` mounts.

    :param checkpoint_dir: directory holding checkpoint files
    :return: full path of the most recently modified candidate, or None
    """
    candidates = [
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if not name.startswith('best')
    ]
    candidates = [path for path in candidates if os.path.isfile(path)]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
コード例 #3
0
import torch
import onnx
from backbone.shufflenet_v2 import shufflenet_v2_x0_5
from helper.utils import load_checkpoint

if __name__ == '__main__':
    model = shufflenet_v2_x0_5()
    model = load_checkpoint(
        model,
        '/home/can/AI_Camera/face_clasification/model/checkpoint_149_0.010453527558476049.tar'
    )
    model.eval()
    # ##################export###############
    output_onnx = "model_filter.onnx"
    print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
    input_names = ["input0"]
    output_names = ["output0"]
    inputs = torch.randn(1, 3, 112, 112)
    torch_out = torch.onnx._export(
        model,
        inputs,
        output_onnx,
        verbose=True,
        input_names=input_names,
        output_names=output_names,
        example_outputs=True,  # to show sample output dimension
        keep_initializers_as_inputs=True,  # to avoid error _Map_base::at
        opset_version=
        7,  # need to change to 11, to deal with tensorflow fix_size input
        # dynamic_axes={
        #     "input0": [2, 3],
コード例 #4
0
ファイル: run_images.py プロジェクト: cannguyen275/isFace

def get_file_ffolder(path):
    """
    Return the full paths of the regular files directly inside ``path``.

    Subdirectories are skipped, so callers that open every returned path as
    an image do not trip over a folder. Results are sorted by file name
    because ``os.listdir`` returns entries in arbitrary order.

    :param path: directory to scan (not recursive)
    :return: list of ``os.path.join(path, name)`` strings, sorted by name
    """
    candidates = (os.path.join(path, name) for name in sorted(os.listdir(path)))
    return [file_path for file_path in candidates if os.path.isfile(file_path)]


# Inference setup: the model and every input tensor are placed on the CPU.
device = torch.device('cpu')
# Per-image preprocessing: PIL image -> float tensor, then per-channel
# normalisation. These mean/std values are the commonly used ImageNet
# statistics; NOTE(review): presumably they match the preprocessing used
# when the checkpoint was trained -- confirm.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Classifier backbone; load_checkpoint returns the model, presumably with the
# trained weights from the given file loaded. The checkpoint path is relative
# to the current working directory.
model = shufflenet_v2_x0_5()
# pretrained = torch.load("model/checkpoint_149_0.043281020134628756.tar")
model = load_checkpoint(model, 'model/checkpoint_149_0.010453527558476049.tar')
# Put modules such as BatchNorm/Dropout into inference behaviour.
model.eval()
model = model.to(device)
# Placeholder: replace "Your Image folder" with the directory of images to classify.
images_path = get_file_ffolder("Your Image folder")
# Class names; presumably index i corresponds to output i of the model -- confirm.
classes = ['face', 'nonface']
for path in images_path:
    # path = '/home/can/AI_Camera/pose_estimation/hinh_2.jpg'
    print(path)
    img = cv2.imread(path)
    # img = cv2.resize(img, (112, 112))
    img = img[..., ::-1]  # RGB
    img = Image.fromarray(img, 'RGB')  # RGB
    img = transformer(img)
    img = img.unsqueeze(0)
    img = img.to(device)
    a = time.time()