Esempio n. 1
0
    def __init__(self,
                 model_file=DEFAULT_MODEL_URL,
                 name='Sequicity'):
        """
        Sequicity initialization.

        Reads the default config file, makes sure the multiwoz data and the
        trained model are present under DEFAULT_DIRECTORY (fetching and
        unzipping them when missing), seeds the random number generators
        from ``cfg.seed``, builds the model, loads its weights and starts a
        session.

        Args:
            model_file (str):
                trained model path or url, handed to ``cached_path`` and
                opened as a zip archive. Only used when the model path named
                in the config does not exist yet. default=DEFAULT_MODEL_URL
            name (str):
                forwarded to the parent class constructor.

        Example:
            sequicity = Sequicity()
        """
        super(Sequicity, self).__init__(name=name)
        config_file = DEFAULT_CONFIG_FILE
        # Use a context manager so the config file handle is always closed.
        with open(config_file) as config_fp:
            c = json.load(config_fp)
        cfg.init_handler(c['tsdf_init'])

        data_root = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/')
        if not os.path.exists(os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')):
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)
            # NOTE(review): cached_path presumably downloads/caches the
            # archive and returns a local path -- confirm against its source.
            archive_file = cached_path(DEFAULT_ARCHIVE_FILE_URL)
            # The with-block closes the archive even if extraction fails.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                print('unzip to', data_root)
                archive.extractall(data_root)

        model_path = os.path.join(DEFAULT_DIRECTORY, c['tsdf_init']['model_path'])
        if not os.path.exists(model_path):
            model_dir = os.path.dirname(model_path)
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(model_dir, exist_ok=True)
            print('Load from model_file param')
            print('down load data from', model_file)
            archive_file = cached_path(model_file)
            with zipfile.ZipFile(archive_file, 'r') as archive:
                print('unzip to', model_dir)
                archive.extractall(model_dir)

        # Seed every RNG source from the configured seed.
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)

        self.m = Model('multiwoz')
        self.m.count_params()
        self.m.load_model()
        self.init_session()
Esempio n. 2
0
def main(arg_mode=None, arg_model=None):
    """
    Command-line entry point: configure, build a model and run one mode.

    The ``-mode``, ``-model`` and ``-cfg`` options are parsed from the
    process command line. ``-cfg`` is the path of a JSON file whose
    ``'tsdf_init'`` entry is passed to ``cfg.init_handler``.

    Args:
        arg_mode (str):
            when not None, overrides the ``-mode`` command-line value.
        arg_model (str):
            when not None, overrides the ``-model`` command-line value.

    Returns:
        The loaded model when the mode is ``'load'``; None for every other
        mode (including modes that match no branch below).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-mode')
    parser.add_argument('-model')
    parser.add_argument('-cfg')
    args = parser.parse_args()

    # Explicit keyword arguments take precedence over the command line.
    if arg_mode is not None:
        args.mode = arg_mode
    if arg_model is not None:
        args.model = arg_model

    # Use a context manager so the config file handle is always closed.
    with open(args.cfg) as config_fp:
        c = json.load(config_fp)
    cfg.init_handler(c['tsdf_init'])

    logging.debug(str(cfg))
    if cfg.cuda:
        logging.debug('Device: {}'.format(torch.cuda.current_device()))
    cfg.mode = args.mode

    # Seed every RNG source from the configured seed.
    torch.manual_seed(cfg.seed)
    torch.cuda.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)

    m = Model(args.model)
    m.count_params()
    if args.mode == 'train':
        # Training from scratch starts from GloVe embeddings instead of a
        # saved checkpoint.
        m.load_glove_embedding()
        m.train()
    elif args.mode == 'adjust':
        # Continue training from a saved checkpoint.
        m.load_model()
        m.train()
    elif args.mode == 'test':
        m.load_model()
        m.eval()
    elif args.mode == 'rl':
        m.load_model()
        m.reinforce_tune()
    elif args.mode == 'interact':
        m.load_model()
        m.interact()
    elif args.mode == 'load':
        m.load_model()
        return m
    return None