Example #1
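# __init__ of a training driver (the enclosing class is not shown in this excerpt):
# it builds ReadBatchData and an Interpreter, collects the training and validation
# pickle files, constructs the NPI model inside a fresh TensorFlow graph, and either
# restores the best saved model from param['model_dir'] or initializes fresh variables.
# Assumed to be imported elsewhere in the original module: os, sys, glob, random,
# pickle as pkl, tensorflow as tf, plus ReadBatchData, Interpreter and NPI.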
 def __init__(self, param):
     self.param = param
     self.read_data = ReadBatchData(param)
     print "initialized read data"
     self.interpreter = Interpreter(self.read_data.program_type_vocab,
                                    self.read_data.argument_type_vocab)
     print "initialized interpreter"
     self.train_data = []
     if not isinstance(param['train_data_file'], list):
         self.training_files = [param['train_data_file']]
     else:
         self.training_files = param['train_data_file']
         random.shuffle(self.training_files)
     print 'Training data loaded'
     sys.stdout.flush()
     self.valid_data = []
     if not isinstance(param['valid_data_file'], list):
         self.valid_files = [param['valid_data_file']]
     else:
         self.valid_files = param['valid_data_file']
     for file in self.valid_files:
         self.valid_data.extend(pkl.load(open(file)))
     if not os.path.exists(param['model_dir']):
         os.mkdir(param['model_dir'])
     self.model_file = os.path.join(param['model_dir'], "best_model")
     with tf.Graph().as_default():
         self.model = NPI(
             param, self.read_data.none_argtype_index,
             self.read_data.num_argtypes, self.read_data.num_progs,
             self.read_data.max_arguments, self.read_data.rel_index,
             self.read_data.type_index, self.read_data.wikidata_rel_embed,
             self.read_data.wikidata_type_embed,
             self.read_data.vocab_init_embed,
             self.read_data.program_to_argtype,
             self.read_data.program_to_targettype)
         self.model.create_placeholder()
         self.action_sequence, self.program_probs, self.gradients = self.model.reinforce()
         self.train_op = self.model.train()
         print 'model created'
         sys.stdout.flush()
         self.saver = tf.train.Saver()
         init = tf.initialize_all_variables()
         self.sess = tf.Session()  # tf_debug.LocalCLIDebugWrapperSession(tf.Session())
         if len(glob.glob(os.path.join(param['model_dir'], '*'))) > 0:
             print "best model exists .. restoring from there "
             self.saver.restore(self.sess, self.model_file)
         else:
             print "initializing fresh variables"
             self.sess.run(init)
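
# Usage sketch (not part of the original source; the wrapper class name and the
# origin of the `param` dict are assumptions):
#   param = json.load(open('parameters.json'))  # hypothetical config file
#   trainer = TrainModel(param)                 # hypothetical class exposing the __init__ above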
Example #2
def evaluate_card_pattern_matching():
    """
    Load NPI model from checkpoint and initialize the REPL for interactive card pattern matching.
    """
    # Load Data
    with open(TEST_PATH, 'rb') as f:
        data = pickle.load(f)

    # Initialize Card Pattern Matching Core
    print('Initializing Card Pattern Matching Core!')
    core = CardPatternMatchingCore()

    # Initialize NPI Model
    npi = NPI(core, CONFIG, LOG_PATH)

    with tf.Session() as sess:
        # Restore from Checkpoint
        saver = tf.train.Saver()
        saver.restore(sess, CKPT_PATH)

        # Run REPL
        repl(sess, npi, data)
Example #3
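# PyTorch variant of the training driver's __init__: it seeds numpy and torch, fills
# in defaults for optional hyper-parameters, loads the training/validation pickles,
# builds the NPI model (restoring weights and training metadata from a checkpoint if
# one exists), sets up an Adam optimizer, and creates the Interpreter.
# Assumed to be imported elsewhere in the original module: os, sys, math, time,
# random, pickle as pkl, numpy as np, torch, plus ReadBatchData, Interpreter, NPI,
# proxy and Lock.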
    def __init__(self, param):
        np.random.seed(1)
        torch.manual_seed(999)
        #if torch.cuda.is_available(): torch.cuda.manual_seed_all(999)
        self.param = param
        self.run_interpreter = True
        self.run_validation = False
        self.generate_data = False
        if 'npi_core_dim' not in self.param:
            self.param['npi_core_dim'] = self.param['hidden_dim']
        if 'cell_dim' not in self.param:
            self.param['cell_dim'] = self.param['hidden_dim']
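        # Presumably maps each question type to the comma-separated generator
        # programs associated with it (inferred from the key/value names; the
        # original code does not document this mapping).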
        self.questype_program_dict = {
            'verify': '',
            'simple': 'gen_set',
            'logical': 'gen_set',
            'quantitative': 'gen_map1',
            'quantitative count': 'gen_set,gen_map1',
            'comparative': 'gen_set,gen_map1',
            'comparative count': 'gen_set,gen_map1',
            'simple,logical': 'gen_set'
        }
        if not self.generate_data and os.path.exists(self.param['model_dir'] +
                                                     '/model_data.pkl'):
            self.pickled_train_data = pkl.load(
                open(self.param['model_dir'] + '/model_data.pkl'))
        else:
            self.pickled_train_data = {}
        self.starting_epoch = 0
        self.starting_overall_step_count = 0
        self.starting_validation_reward_overall = 0
        self.starting_validation_reward_topbeam = 0
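        # Fill in defaults for optional settings that may be missing from the
        # supplied param dict.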
        if 'dont_look_back_attention' not in self.param:
            self.param['dont_look_back_attention'] = False
        if 'concat_query_npistate' not in self.param:
            self.param['concat_query_npistate'] = False
        if 'query_attention' not in self.param:
            self.param['query_attention'] = False
        if self.param['dont_look_back_attention']:
            self.param['query_attention'] = True
        if 'single_reward_function' not in self.param:
            self.param['single_reward_function'] = False
        if 'terminate_prog' not in self.param:
            self.param['terminate_prog'] = False
            terminate_prog = False
        else:
            terminate_prog = self.param['terminate_prog']
        if 'none_decay' not in self.param:
            self.param['none_decay'] = 0

        if 'train_mode' not in self.param:
            self.param['train_mode'] = 'reinforce'
        self.qtype_wise_batching = self.param['questype_wise_batching']
        self.read_data = ReadBatchData(param)
        print "initialized read data"
        if 'quantitative' in self.param[
                'question_type'] or 'comparative' in self.param[
                    'question_type']:
            if 'relaxed_reward_till_epoch' in self.param:
                relaxed_reward_till_epoch = self.param[
                    'relaxed_reward_till_epoch']
            else:
                self.param['relaxed_reward_till_epoch'] = [-1, -1]
                relaxed_reward_till_epoch = [-1, -1]
        else:
            self.param['relaxed_reward_till_epoch'] = [-1, -1]
            relaxed_reward_till_epoch = [-1, -1]
        if 'params_turn_on_after' not in self.param:
            self.param['params_turn_on_after'] = 'epoch'
        if self.param['params_turn_on_after'] != 'epoch' and self.param[
                'params_turn_on_after'] != 'batch':
            raise Exception('params_turn_on_after should be epoch or batch')
        if 'print' in self.param:
            self.printing = self.param['print']
        else:
            self.param['print'] = False
            self.printing = True
        if 'prune_beam_type_mismatch' not in self.param:
            self.param['prune_beam_type_mismatch'] = 0
        if 'prune_after_epoch_no.' not in self.param:
            self.param['prune_after_epoch_no.'] = [
                self.param['max_epochs'], 1000000
            ]
        if self.param['question_type'] == 'verify':
            boolean_reward_multiplier = 1
        else:
            boolean_reward_multiplier = 0.1
        if 'print_valid_freq' not in self.param:
            self.param['print_valid_freq'] = self.param['print_train_freq']
        if 'valid_freq' not in self.param:
            self.param['valid_freq'] = 100
        if 'unused_var_penalize_after_epoch' not in self.param:
            self.param['unused_var_penalize_after_epoch'] = [
                self.param['max_epochs'], 1000000
            ]
        unused_var_penalize_after_epoch = self.param[
            'unused_var_penalize_after_epoch']
        if 'epoch_for_feasible_program_at_last_step' not in self.param:
            self.param['epoch_for_feasible_program_at_last_step'] = [
                self.param['max_epochs'], 1000000
            ]
        if 'epoch_for_biasing_program_sample_with_target' not in self.param:
            self.param['epoch_for_biasing_program_sample_with_target'] = [
                self.param['max_epochs'], 1000000
            ]
        if 'epoch_for_biasing_program_sample_with_last_variable' not in self.param:
            self.param[
                'epoch_for_biasing_program_sample_with_last_variable'] = [
                    self.param['max_epochs'], 100000
                ]
        if 'use_var_key_as_onehot' not in self.param:
            self.param['use_var_key_as_onehot'] = False
        if 'reward_function' not in self.param:
            reward_func = "jaccard"
            self.param['reward_function'] = "jaccard"
        else:
            reward_func = self.param['reward_function']
        if 'relaxed_reward_strict' not in self.param:
            relaxed_reward_strict = False
            self.param['relaxed_reward_strict'] = relaxed_reward_strict
        else:
            relaxed_reward_strict = self.param['relaxed_reward_strict']
        if param['parallel'] == 1:
            raise Exception(
                'Need to fix the intermediate rewards for parallelly executing interpreter'
            )
        for k, v in param.items():
            print 'PARAM: ', k, ':: ', v
        print 'loaded params '
        self.train_data = []
        if os.path.isdir(param['train_data_file']):
            self.training_files = [
                param['train_data_file'] + '/' + x
                for x in os.listdir(param['train_data_file'])
                if x.endswith('.pkl')
            ]
        elif not isinstance(param['train_data_file'], list):
            self.training_files = [param['train_data_file']]
        else:
            self.training_files = param['train_data_file']
            random.shuffle(self.training_files)
        self.valid_data = []
        if os.path.isdir(param['valid_data_file']):
            self.valid_files = [
                param['valid_data_file'] + '/' + x
                for x in os.listdir(param['valid_data_file'])
                if x.endswith('.pkl')
            ]
        elif not isinstance(param['valid_data_file'], list):
            self.valid_files = [param['valid_data_file']]
        else:
            self.valid_files = param['valid_data_file']
        for file in self.valid_files:
            temp = pkl.load(open(file))
            temp = self.remove_bad_data(temp)
            temp = self.add_data_id(temp)
            self.valid_data.extend(temp)
        if self.qtype_wise_batching:
            self.valid_data_map = self.read_data.get_data_per_questype(
                self.valid_data)
            self.valid_batch_size_types = self.get_batch_size_per_type(
                self.valid_data_map)
            self.n_valid_batches = int(
                math.ceil(
                    float(sum([len(x) for x in self.valid_data_map.values()])) /
                    float(self.param['batch_size'])))
        else:
            self.n_valid_batches = int(
                math.ceil(
                    float(len(self.valid_data)) /
                    float(self.param['batch_size'])))

        if not os.path.exists(param['model_dir']):
            os.mkdir(param['model_dir'])
        self.model_file = os.path.join(param['model_dir'], param['model_file'])
        learning_rate = param['learning_rate']
        start = time.time()
        self.model = NPI(self.param, self.read_data.none_argtype_index, self.read_data.num_argtypes, \
                         self.read_data.num_progs, self.read_data.max_arguments, \
                         self.read_data.rel_index, self.read_data.type_index, \
                         self.read_data.wikidata_rel_embed, self.read_data.wikidata_type_embed, \
                         self.read_data.vocab_init_embed, self.read_data.program_to_argtype, \
                         self.read_data.program_to_targettype)
        self.checkpoint_prefix = os.path.join(param['model_dir'],
                                              param['model_file'])
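        # Resume from an existing checkpoint if one is present, restoring the model
        # weights and the training metadata (epoch, step count and best validation
        # rewards) recorded in metadata.txt.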
        if os.path.exists(self.checkpoint_prefix):
            self.model.load_state_dict(torch.load(self.checkpoint_prefix))
            fr = open(self.param['model_dir'] + '/metadata.txt').readlines()
            self.starting_epoch = int(fr[0].split(' ')[1].strip())
            self.starting_overall_step_count = int(fr[1].split(' ')[1].strip())
            self.starting_validation_reward_overall = float(
                fr[2].split(' ')[1].strip())
            self.starting_validation_reward_topbeam = float(
                fr[3].split(' ')[1].strip())
            print 'restored model'
        end = time.time()
        if torch.cuda.is_available():
            self.model.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=learning_rate,
                                          betas=[0.9, 0.999],
                                          weight_decay=1e-5)
        print self.model
        print 'model created in ', (end - start), 'seconds'

        self.interpreter = Interpreter(
            self.param['wikidata_dir'], self.param['num_timesteps'],
            self.read_data.program_type_vocab,
            self.read_data.argument_type_vocab, self.printing, terminate_prog,
            relaxed_reward_strict, reward_function=reward_func,
            boolean_reward_multiplier=boolean_reward_multiplier,
            relaxed_reward_till_epoch=relaxed_reward_till_epoch,
            unused_var_penalize_after_epoch=unused_var_penalize_after_epoch)
        if self.param['parallel'] == 1:
            self.InterpreterProxy, self.InterpreterProxyListener = proxy.createProxy(
                self.interpreter)
            self.interpreter.parallel = 1
            self.lock = Lock()
        print "initialized interpreter"
Example #4
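# TensorFlow variant of the training driver's __init__: same parameter defaulting and
# data loading as the PyTorch version above, but the NPI model is built inside a
# tf.Graph, its REINFORCE ops come from model.reinforce(), and the session is restored
# from the latest checkpoint in param['model_dir'] when one exists.
# Assumed to be imported elsewhere in the original module: os, sys, math, time,
# random, pickle as pkl, numpy as np, tensorflow as tf, the TensorFlow debugger as
# tf_debug, plus ReadBatchData, Interpreter, NPI, proxy and Lock.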
    def __init__(self, param):
        np.random.seed(1)
        tf.set_random_seed(1)
        self.param = param
        if 'normalize_wrt_num_args' not in self.param:
            self.param['normalize_wrt_num_args'] = False
        if 'dont_look_back_attention' not in self.param:
            self.param['dont_look_back_attention'] = False
        if 'concat_query_npistate' not in self.param:
            self.param['concat_query_npistate'] = False
        if 'query_attention' not in self.param:
            self.param['query_attention'] = False
        if self.param['dont_look_back_attention']:
            self.param['query_attention'] = True
        if 'single_reward_function' not in self.param:
            self.param['single_reward_function'] = False
        if 'terminate_prog' not in self.param:
            self.param['terminate_prog'] = False
            terminate_prog = False
        else:
            terminate_prog = self.param['terminate_prog']
        if 'train_mode' not in self.param:
            self.param['train_mode'] = 'reinforce'
        self.qtype_wise_batching = self.param['questype_wise_batching']
        self.read_data = ReadBatchData(param)
        print "initialized read data"
        if 'quantitative' in self.param[
                'question_type'] or 'comparative' in self.param[
                    'question_type']:
            if 'relaxed_reward_till_epoch' in self.param:
                relaxed_reward_till_epoch = self.param[
                    'relaxed_reward_till_epoch']
            else:
                self.param['relaxed_reward_till_epoch'] = [-1, -1]
                relaxed_reward_till_epoch = [-1, -1]
        else:
            self.param['relaxed_reward_till_epoch'] = [-1, -1]
            relaxed_reward_till_epoch = [-1, -1]
        if 'params_turn_on_after' not in self.param:
            self.param['params_turn_on_after'] = 'epoch'
        if self.param['params_turn_on_after'] != 'epoch' and self.param[
                'params_turn_on_after'] != 'batch':
            raise Exception('params_turn_on_after should be epoch or batch')
        if 'print' in self.param:
            self.printing = self.param['print']
        else:
            self.param['print'] = False
            self.printing = True
        if 'prune_beam_type_mismatch' not in self.param:
            self.param['prune_beam_type_mismatch'] = 0
        if 'prune_after_epoch_no.' not in self.param:
            self.param['prune_after_epoch_no.'] = [
                self.param['max_epochs'], 1000000
            ]
        if self.param['question_type'] == 'verify':
            boolean_reward_multiplier = 1
        else:
            boolean_reward_multiplier = 0.1
        if 'none_decay' not in self.param:
            self.param['none_decay'] = 0
        if 'print_test_freq' not in self.param:
            self.param['print_test_freq'] = self.param['print_train_freq']
        if 'unused_var_penalize_after_epoch' not in self.param:
            self.param['unused_var_penalize_after_epoch'] = [
                self.param['max_epochs'], 1000000
            ]
        unused_var_penalize_after_epoch = self.param[
            'unused_var_penalize_after_epoch']
        if 'epoch_for_feasible_program_at_last_step' not in self.param:
            self.param['epoch_for_feasible_program_at_last_step'] = [
                self.param['max_epochs'], 1000000
            ]
        if 'epoch_for_biasing_program_sample_with_target' not in self.param:
            self.param['epoch_for_biasing_program_sample_with_target'] = [
                self.param['max_epochs'], 1000000
            ]
        if 'epoch_for_biasing_program_sample_with_last_variable' not in self.param:
            self.param[
                'epoch_for_biasing_program_sample_with_last_variable'] = [
                    self.param['max_epochs'], 100000
                ]
        if 'use_var_key_as_onehot' not in self.param:
            self.param['use_var_key_as_onehot'] = False
        if 'reward_function' not in self.param:
            reward_func = "jaccard"
            self.param['reward_function'] = "jaccard"
        else:
            reward_func = self.param['reward_function']
        if 'relaxed_reward_strict' not in self.param:
            relaxed_reward_strict = False
            self.param['relaxed_reward_strict'] = relaxed_reward_strict
        else:
            relaxed_reward_strict = self.param['relaxed_reward_strict']
        if param['parallel'] == 1:
            raise Exception(
                'Need to fix the intermediate rewards for parallelly executing interpreter'
            )
        for k, v in param.items():
            print 'PARAM: ', k, ':: ', v
        print 'loaded params '
        self.train_data = []
        if os.path.isdir(param['train_data_file']):
            self.training_files = [
                param['train_data_file'] + '/' + x
                for x in os.listdir(param['train_data_file'])
                if x.endswith('.pkl')
            ]
        elif not isinstance(param['train_data_file'], list):
            self.training_files = [param['train_data_file']]
        else:
            self.training_files = param['train_data_file']
            random.shuffle(self.training_files)
        sys.stdout.flush()
        self.test_data = []
        if os.path.isdir(param['test_data_file']):
            self.test_files = [
                param['test_data_file'] + '/' + x
                for x in os.listdir(param['test_data_file'])
                if x.endswith('.pkl')
            ]
        elif not isinstance(param['test_data_file'], list):
            self.test_files = [param['test_data_file']]
        else:
            self.test_files = param['test_data_file']
        for file in self.test_files:
            self.test_data.extend(pkl.load(open(file)))
        if self.qtype_wise_batching:
            self.test_data_map = self.read_data.get_data_per_questype(
                self.test_data)
            self.test_batch_size_types = self.get_batch_size_per_type(
                self.test_data_map)
            self.n_test_batches = int(
                math.ceil(
                    float(sum([len(x) for x in self.test_data_map.values()])) /
                    float(self.param['batch_size'])))
        else:
            self.n_test_batches = int(
                math.ceil(
                    float(len(self.test_data)) /
                    float(self.param['batch_size'])))

        if not os.path.exists(param['model_dir']):
            os.mkdir(param['model_dir'])
        self.model_file = os.path.join(param['model_dir'], param['model_file'])
        with tf.Graph().as_default():
            start = time.time()
            self.model = NPI(param, self.read_data.none_argtype_index, self.read_data.num_argtypes, \
                             self.read_data.num_progs, self.read_data.max_arguments, \
                             self.read_data.rel_index, self.read_data.type_index, \
                             self.read_data.wikidata_rel_embed, self.read_data.wikidata_type_embed, \
                             self.read_data.vocab_init_embed, self.read_data.program_to_argtype, \
                             self.read_data.program_to_targettype)
            self.model.create_placeholder()
            [self.action_sequence, self.program_probs, self.logProgramProb, self.Reward_placeholder, self.Relaxed_rewards_placeholder, \
             self.train_op, self.loss, self.beam_props, self.per_step_probs, self.IfPosIntermediateReward, \
             self.mask_IntermediateReward, self.IntermediateReward] = self.model.reinforce()
            #self.program_keys, self.program_embedding, self.word_embeddings, self.argtype_embedding, self.query_attention_h_mat = self.model.get_parameters()
            if param['Debug'] == 0:
                config = tf.ConfigProto()
                config.gpu_options.allow_growth = True
                self.sess = tf.Session(config=config)
            else:
                self.sess = tf_debug.LocalCLIDebugWrapperSession(tf.Session())
            self.saver = tf.train.Saver()

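            # Restore from the most recent checkpoint in model_dir if one exists;
            # otherwise initialize all variables from scratch.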
            ckpt = tf.train.get_checkpoint_state(param['model_dir'])
            if ckpt and ckpt.model_checkpoint_path:
                print "best model exists in ", self.model_file, "... restoring from there "
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
                print 'restored model'
            else:
                init = tf.global_variables_initializer()
                self.sess.run(init)
                print 'initialized model'
            end = time.time()
            print 'model created in ', (end - start), 'seconds'
            sys.stdout.flush()

        self.interpreter = Interpreter(
            self.param['wikidata_dir'], self.param['num_timesteps'],
            self.read_data.program_type_vocab,
            self.read_data.argument_type_vocab, self.printing, terminate_prog,
            relaxed_reward_strict, reward_function=reward_func,
            boolean_reward_multiplier=boolean_reward_multiplier,
            relaxed_reward_till_epoch=relaxed_reward_till_epoch,
            unused_var_penalize_after_epoch=unused_var_penalize_after_epoch)
        if self.param['parallel'] == 1:
            self.InterpreterProxy, self.InterpreterProxyListener = proxy.createProxy(
                self.interpreter)
            self.interpreter.parallel = 1
            self.lock = Lock()
        print "initialized interpreter"