Ejemplo n.º 1
0
def run_live(args):
    """Run a pre-trained model on one live sample and return probabilities.

    The request is combined with the model- and dataset-specific settings
    read from ``config.json``, the checkpoint named by ``pre_trained_model``
    is loaded onto the CPU, features are obtained from ``MLive`` and a single
    forward pass is made without gradient tracking.

    Args:
        args: Mapping describing the request. ``language``, ``modelName``
            and ``datasetName`` are read from it directly. Together with the
            configuration entries it must also provide ``seed``,
            ``pre_trained_model``, ``live_working_dir``, ``transcript``,
            ``seq_lens``, ``feature_dims``, ``need_normalized`` and
            ``annotations``; ``tasks`` is optional and defaults to ``'M'``.
            The mapping is copied, so the caller's object is not modified.

    Returns:
        dict: One entry per element of ``tasks`` (iterating a string yields
        its characters), each of the form
        ``{'model': <modelName>, 'probs': {<annotation>: <probability>}}``
        where every probability is rounded to four decimals and stored as a
        string.

    Raises:
        KeyError: If ``language``, the model name or the dataset name is
            missing from the request or from ``config.json``.
        TypeError: If the model settings, dataset settings and request share
            a key (``dict(**a, **b)`` rejects duplicate keywords).
    """
    print(args)
    # Work on a copy: the rename below removes 'language', which would make a
    # second call with the same request object fail with KeyError.
    args = dict(args)
    args['text_language'] = args.pop('language')

    # Look up the model- and dataset-specific settings.
    with open(os.path.join(MM_CODES_PATH, 'config.json'), 'r') as fp:
        config = json.load(fp)
    model_args = config["MODELS"][args['modelName']]['args'][
        args['datasetName']]
    dataset_args = config["DATASETS"][args['datasetName']]
    if model_args.get('need_data_aligned'):
        dataset_args = dataset_args['aligned']
    else:
        dataset_args = dataset_args['unaligned']

    # Merge everything into a single settings object with attribute access.
    args = Storage(dict(**model_args, **dataset_args, **args))
    args.device = 'cpu'
    print(args)

    # Build the model and restore the pre-trained weights.
    setup_seed(args.seed)
    model = AMIO(args)
    pretrained_model_path = os.path.join(MODEL_TMP_SAVE,
                                         args.pre_trained_model)
    # Inference is pinned to the CPU, so remap the stored tensors explicitly;
    # without map_location torch.load restores them onto the device they were
    # saved from, which fails for GPU checkpoints on a CPU-only host.
    model.load_state_dict(
        torch.load(pretrained_model_path, map_location=args.device))

    # Prepare the raw data and fetch the three modality features.
    dp = MLive(args.live_working_dir, args.transcript, args.text_language)
    dp.dataPre()
    text, audio, video = dp.getEmbeddings(args.seq_lens, args.feature_dims)
    # Add a leading dimension of size 1 (presumably the batch axis).
    text = torch.Tensor(text).unsqueeze(0)
    audio = torch.Tensor(audio).unsqueeze(0)
    video = torch.Tensor(video).unsqueeze(0)
    print(text.size(), audio.size(), video.size())

    if args.need_normalized:
        # Average over dim 1 while keeping that dimension with size 1.
        audio = torch.mean(audio, dim=1, keepdims=True)
        video = torch.mean(video, dim=1, keepdims=True)

    # Predict.
    model.eval()
    with torch.no_grad():
        outputs = model(text, audio, video)

    if 'tasks' not in args:
        args.tasks = 'M'

    # Invert the annotations mapping so entries can be looked up by the
    # position of each output value (presumably label -> class index).
    annotation_dict = {v: k for k, v in args.annotations.items()}

    ret = {}
    for m in args.tasks:
        cur_output = outputs[m].detach().squeeze().numpy()
        cur_output = np.clip(cur_output, 1e-8, 10)
        cur_output = softmax(cur_output)
        # Values are stored as strings because json cannot serialize float32.
        probs = {
            annotation_dict[i]: str(round(v, 4))
            for i, v in enumerate(cur_output)
        }

        ret[m] = {'model': args.modelName, "probs": probs}

    print(ret)
    return ret
Ejemplo n.º 2
0
"""
Test the behaviour of the ATTENTION mechanism used in MULT: does the strong
get stronger while everything else gets weaker?
"""
import torch

from config.config_run import Config
from data.load_data import MMDataLoader
from models.AMIO import AMIO
from run import parse_args
from trains.ATIO import ATIO

# Checkpoint to evaluate (absolute, machine-specific path).
model_path = "/home/zhuchuanbo/paper_code/results/model_saves/mult-sims-M.pth"

# Build the run configuration from the command-line arguments.
configs = Config(parse_args()).get_config()
# Use the first GPU id listed in the configuration.
device = torch.device('cuda:%d' % configs.gpu_ids[0])
configs.device = device

# Create the data loader(s) and the model, then restore the saved weights.
dataloader = MMDataLoader(configs)
model = AMIO(configs).to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a different CUDA index still loads.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Run the test pass on the 'test' entry of the data loader.
atio = ATIO().get_train(configs)
results = atio.do_test(model, dataloader['test'], mode="TEST")

import torchvision.models as models