def _translate(self, process_id, input_item, get_sampler, sess):
    """Translate one queued input item by sampling from the model.

    Args:
        process_id: identifier of the calling worker; not used here.
        input_item: object providing ``k`` (handed to ``get_sampler``),
            ``batch`` (the source sentences) and ``normalization_alpha``.
        get_sampler: callable that maps ``input_item.k`` to a sampler.
        sess: session forwarded to ``translate_utils.translate_batch``.

    Returns:
        The result of ``translate_utils.translate_batch`` for the batch.
    """
    sampler_key = input_item.k
    sources = input_item.batch
    alpha = input_item.normalization_alpha
    first_options = self._options[0]

    # prepare_data also wants target sentences; a zero placeholder of shape
    # (num_sentences, 1) stands in for them and its outputs are discarded.
    placeholder_targets = numpy.zeros(shape=(len(sources), 1))
    x, x_mask, _, _ = util.prepare_data(
        sources, placeholder_targets, first_options.factors, maxlen=None)

    return translate_utils.translate_batch(
        session=sess,
        sampler=get_sampler(sampler_key),
        x=x,
        x_mask=x_mask,
        max_translation_len=first_options.translation_maxlen,
        normalization_alpha=alpha)
def _mrt_sample_in_chunks(sess, sampler, x, x_mask, max_translation_len,
                          normalization_alpha, max_sentences):
    """Run translate_batch over x, optionally in sub-batches along the batch axis.

    When ``max_sentences > 0`` the batch is split into
    ``ceil(batch_size / max_sentences)`` roughly equal parts, so that each
    device call sees at most about ``max_sentences`` sentences.
    Returns a flat list with one entry per source sentence, in order.
    """
    if max_sentences > 0:
        num_split = math.ceil(x_mask.shape[1] / max_sentences)
        split_x = np.array_split(x, num_split, 2)
        split_x_mask = np.array_split(x_mask, num_split, 1)
        results = []
        for sub_x, sub_x_mask in zip(split_x, split_x_mask):
            # list.extend instead of np.concatenate: the per-sentence results
            # are ragged (array, score) lists that NumPy cannot stack.
            results.extend(translate_utils.translate_batch(
                sess, sampler, sub_x, sub_x_mask, max_translation_len,
                normalization_alpha))
        return results
    return list(translate_utils.translate_batch(
        sess, sampler, x, x_mask, max_translation_len, normalization_alpha))


def _mrt_to_token_lists(sample_and_score):
    """Keep only the token-id sequence (as a list) of every (sequence, score) pair."""
    return [[candidate[0].tolist() for candidate in per_sentence]
            for per_sentence in sample_and_score]


def _mrt_deduplicate(candidates):
    """Return the distinct candidates in sorted order."""
    return [seq for seq, _ in itertools.groupby(sorted(candidates))]


def _mrt_pad_candidates(samples):
    """Stack variable-length candidates into time-major id and mask matrices.

    Padding uses id 0 (the original code notes this is <EOS>), and the
    mask covers each candidate plus one extra position.
    """
    lengths_y = [len(s) for s in samples]
    maxlen_y = max(lengths_y) + 1
    y_new = np.zeros((maxlen_y, len(samples)), dtype='int64')
    y_mask_new = np.zeros((maxlen_y, len(samples)), dtype='float32')
    for idx, (seq, length) in enumerate(zip(samples, lengths_y)):
        y_new[:length, idx] = seq
        y_mask_new[:length + 1, idx] = 1.
    return y_new, y_mask_new


def full_sampler(replica, sampler, sess, config, x, x_mask, y, y_mask):
    """Generate candidate sentences used for Minimum Risk Training (MRT).

    Args:
        replica: inference model; not used by this function (kept for
            interface compatibility).
        sampler: sampler forwarded to ``translate_utils.translate_batch``.
        sess: session forwarded to ``translate_utils.translate_batch``.
        config: options object. Reads ``sample_way`` ('beam_search' or
            'randomly_sample'), ``max_len_a``, ``max_len_b``,
            ``translation_maxlen``, ``max_sentences_of_sampling``,
            ``normalization_alpha`` and ``mrt_reference``.
        x: (factor, len, batch_size)
        x_mask: (len, batch_size)
        y: (len, batch_size) references
        y_mask: (len, batch_size)

    Returns:
        A 6-tuple ``(x, x_mask, y, y_mask, y_ref, index)``. The first four
        are nested lists holding one column per source-candidate pair:

        - x: (factor, len, n_candidates); each source column is repeated
          once per candidate of that sentence.
        - x_mask: (len, n_candidates).
        - y: (maxlen + 1, n_candidates) candidate token ids, zero padded.
        - y_mask: (maxlen + 1, n_candidates).
        - y_ref: batch-major list of references. When
          ``config.mrt_reference`` is set, each reference is truncated to
          its mask length.
        - index: ``[[0, e_1, ..., e_B]]`` cumulative offsets. The candidates
          of source sentence ``i`` occupy columns
          ``index[0][i]:index[0][i + 1]``.

    Raises:
        ValueError: if ``config.sample_way`` is not recognised.
    """
    if config.sample_way not in ('beam_search', 'randomly_sample'):
        raise ValueError(
            'unsupported sample_way: {!r}'.format(config.sample_way))

    # set maximum number of tokens of sampled candidates
    dynamic_max_len = int(config.max_len_a * x_mask.shape[0] + config.max_len_b)
    max_translation_len = min(config.translation_maxlen, dynamic_max_len)

    if config.sample_way == 'beam_search':
        sample_and_score = _mrt_sample_in_chunks(
            sess, sampler, x, x_mask, max_translation_len,
            config.normalization_alpha, config.max_sentences_of_sampling)
        # beam search hypotheses are not deduplicated
        samples = _mrt_to_token_lists(sample_and_score)
    else:
        # normalization_alpha is 0 for random sampling (it has no effect on
        # the sampled sentences)
        sample_and_score = _mrt_sample_in_chunks(
            sess, sampler, x, x_mask, max_translation_len, 0.0,
            config.max_sentences_of_sampling)
        samples = [_mrt_deduplicate(candidates)
                   for candidates in _mrt_to_token_lists(sample_and_score)]
    # samples: (batch_size, n_candidates_i, len), ragged

    # convert references from time-major to batch-major lists
    y = [list(column) for column in zip(*y)]
    y_mask = [list(column) for column in zip(*y_mask)]

    if config.mrt_reference:
        for i, candidates in enumerate(samples):
            # drop the padding of the reference
            y[i] = y[i][:int(sum(y_mask[i]))]
            if y[i] not in candidates:
                # the reference takes the place of the last candidate, so the
                # number of candidates per sentence is unchanged
                if candidates:
                    candidates[-1] = y[i]
                else:
                    candidates.append(y[i])

    # repeat each source column once per remaining candidate, and record
    # where each sentence's candidates start (used to normalise the
    # distribution and compute the expected risk)
    counts = [len(candidates) for candidates in samples]
    x_new = np.repeat(x, counts, axis=2)
    x_mask_new = np.repeat(x_mask, counts, axis=1)
    index = [[0] + list(itertools.accumulate(counts))]

    # flatten (batch_size, n_candidates_i) into one list of candidates
    flat_samples = [seq for candidates in samples for seq in candidates]
    y_new, y_mask_new = _mrt_pad_candidates(flat_samples)

    return (x_new.tolist(), x_mask_new.tolist(), y_new.tolist(),
            y_mask_new.tolist(), y, index)
def train(config, sess):
    """Build the model graph and run the training loop.

    Creates one model replica per available GPU (one CPU replica when none
    is found) and sets up the learning-rate schedule, optimizer, optional
    summary writer and optional exponential smoothing. It then restores or
    initialises the variables and iterates over the training data for up
    to ``config.max_epochs`` epochs.

    At intervals set by the ``*_freq`` options, the loop logs progress,
    prints random/beam samples, validates and saves checkpoints. It stops
    early when validation stops improving for more than
    ``config.patience`` checks, or after ``config.finish_after`` updates.

    When ``config.print_per_token_pro`` is set, only one more epoch is run.
    The values returned by ``updater.update`` are then appended to the file
    it names instead of being accumulated as loss.
    """
    assert (config.prior_model is not None and (tf.train.checkpoint_exists(os.path.abspath(config.prior_model))) or (config.map_decay_c == 0.0)), \
        "MAP training requires a prior model file: Use command-line option --prior_model"

    # Construct the graph, with one model replica per GPU
    num_gpus = len(tf_utils.get_available_gpus())
    num_replicas = max(1, num_gpus)

    if config.loss_function == 'MRT':
        assert config.gradient_aggregation_steps == 1
        assert config.max_sentences_per_device == 0, "MRT mode does not support sentence-based split"
        if config.max_tokens_per_device != 0:
            assert (config.samplesN * config.maxlen <= config.max_tokens_per_device), \
                "need to make sure candidates of a sentence could be feed into the model"
        else:
            assert num_replicas == 1, "MRT mode does not support sentence-based split"
            assert (config.samplesN * config.maxlen <= config.token_batch_size), \
                "need to make sure candidates of a sentence could be feed into the model"

    logging.info('Building model...')
    replicas = []
    for i in range(num_replicas):
        device_type = "GPU" if num_gpus > 0 else "CPU"
        device_spec = tf.DeviceSpec(device_type=device_type, device_index=i)
        with tf.device(device_spec):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                if config.model_type == "transformer":
                    model = TransformerModel(config)
                else:
                    model = rnn_model.RNNModel(config)
                replicas.append(model)

    init = tf.zeros_initializer(dtype=tf.int32)
    global_step = tf.get_variable('time', [], initializer=init, trainable=False)

    if config.learning_schedule == "constant":
        schedule = learning_schedule.ConstantSchedule(config.learning_rate)
    elif config.learning_schedule == "transformer":
        schedule = learning_schedule.TransformerSchedule(
            global_step=global_step,
            dim=config.state_size,
            warmup_steps=config.warmup_steps)
    elif config.learning_schedule == "warmup-plateau-decay":
        schedule = learning_schedule.WarmupPlateauDecaySchedule(
            global_step=global_step,
            peak_learning_rate=config.learning_rate,
            warmup_steps=config.warmup_steps,
            plateau_steps=config.plateau_steps)
    else:
        logging.error('Learning schedule type is not valid: {}'.format(
            config.learning_schedule))
        sys.exit(1)

    if config.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=schedule.learning_rate,
                                           beta1=config.adam_beta1,
                                           beta2=config.adam_beta2,
                                           epsilon=config.adam_epsilon)
    else:
        logging.error('No valid optimizer defined: {}'.format(config.optimizer))
        sys.exit(1)

    if config.summary_freq:
        summary_dir = (config.summary_dir if config.summary_dir is not None
                       else os.path.abspath(os.path.dirname(config.saveto)))
        writer = tf.summary.FileWriter(summary_dir, sess.graph)
    else:
        writer = None

    updater = ModelUpdater(config, num_gpus, replicas, optimizer, global_step,
                           writer)

    if config.exponential_smoothing > 0.0:
        smoothing = ExponentialSmoothing(config.exponential_smoothing)

    saver, progress = model_loader.init_or_restore_variables(
        config, sess, train=True)

    global_step.load(progress.uidx, sess)

    if config.sample_freq:
        random_sampler = RandomSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=1)

    if config.beam_freq or config.valid_script is not None:
        beam_search_sampler = BeamSearchSampler(
            models=[replicas[0]],
            configs=[config],
            beam_size=config.beam_size)

    # save model options
    write_config_to_json_file(config, config.saveto)

    text_iterator, valid_text_iterator = load_data(config)
    _, _, num_to_source, num_to_target = util.load_dictionaries(config)
    total_loss = 0.
    n_sents, n_words = 0, 0
    last_time = time.time()
    logging.info("Initial uidx={}".format(progress.uidx))
    # when printing per-token probabilities, run only one (more) epoch
    if config.print_per_token_pro:
        config.max_epochs = progress.eidx + 1
    for progress.eidx in range(progress.eidx, config.max_epochs):
        logging.info('Starting epoch {0}'.format(progress.eidx))

        for source_sents, target_sents in text_iterator:
            if len(source_sents[0][0]) != config.factors:
                logging.error('Mismatch between number of factors in settings ({0}), and number in training corpus ({1})\n'.format(config.factors, len(source_sents[0][0])))
                sys.exit(1)
            x_in, x_mask_in, y_in, y_mask_in = util.prepare_data(
                source_sents, target_sents, config.factors, maxlen=None)
            if x_in is None:
                logging.info('Minibatch with zero sample under length {0}'.format(config.maxlen))
                continue
            write_summary_for_this_batch = config.summary_freq and ((progress.uidx % config.summary_freq == 0) or (config.finish_after and progress.uidx % config.finish_after == 0))
            _, _, batch_size = x_in.shape

            output = updater.update(
                sess, x_in, x_mask_in, y_in, y_mask_in, num_to_target,
                write_summary_for_this_batch)

            # same truthiness test as the epoch override above
            if not config.print_per_token_pro:
                total_loss += output
            else:
                # write per-token probability into the file; `with` closes
                # it even if a write fails
                with open(config.print_per_token_pro, 'a') as f:
                    for pro in output:
                        f.write(str(pro) + '\n')

            n_sents += batch_size
            n_words += int(numpy.sum(y_mask_in))
            progress.uidx += 1

            # Update the smoothed version of the model variables.
            # To reduce the performance overhead, we only do this once every
            # N steps (the smoothing factor is adjusted accordingly).
            if config.exponential_smoothing > 0.0 and progress.uidx % smoothing.update_frequency == 0:
                sess.run(fetches=smoothing.update_ops)

            if config.disp_freq and progress.uidx % config.disp_freq == 0:
                duration = time.time() - last_time
                disp_time = datetime.now().strftime('[%Y-%m-%d %H:%M:%S]')
                logging.info('{0} Epoch: {1} Update: {2} Loss/word: {3} Words/sec: {4} Sents/sec: {5}'.format(disp_time, progress.eidx, progress.uidx, total_loss/n_words, n_words/duration, n_sents/duration))
                last_time = time.time()
                total_loss = 0.
                n_sents = 0
                n_words = 0

            if config.sample_freq and progress.uidx % config.sample_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, random_sampler, x_small, x_mask_small,
                    config.translation_maxlen, 0.0)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    sample = util.seq2words(ss[0][0], num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    logging.info('SAMPLE: {}'.format(sample))

            if config.beam_freq and progress.uidx % config.beam_freq == 0:
                x_small = x_in[:, :, :10]
                x_mask_small = x_mask_in[:, :10]
                y_small = y_in[:, :10]
                samples = translate_utils.translate_batch(
                    sess, beam_search_sampler, x_small, x_mask_small,
                    config.translation_maxlen, config.normalization_alpha)
                assert len(samples) == len(x_small.T) == len(y_small.T), \
                    (len(samples), x_small.shape, y_small.shape)
                for xx, yy, ss in zip(x_small.T, y_small.T, samples):
                    source = util.factoredseq2words(xx, num_to_source)
                    target = util.seq2words(yy, num_to_target)
                    logging.info('SOURCE: {}'.format(source))
                    logging.info('TARGET: {}'.format(target))
                    for i, (sample_seq, cost) in enumerate(ss):
                        sample = util.seq2words(sample_seq, num_to_target)
                        # length in tokens (not characters of the decoded
                        # string); guard against an empty hypothesis
                        n_tokens = len(sample_seq)
                        avg_cost = cost / n_tokens if n_tokens else float('nan')
                        msg = 'SAMPLE {}: {} Cost/Len/Avg {}/{}/{}'.format(
                            i, sample, cost, n_tokens, avg_cost)
                        logging.info(msg)

            if config.valid_freq and progress.uidx % config.valid_freq == 0:
                if config.exponential_smoothing > 0.0:
                    # validate with the smoothed weights, then swap back
                    sess.run(fetches=smoothing.swap_ops)
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                    sess.run(fetches=smoothing.swap_ops)
                else:
                    valid_ce = validate(sess, replicas[0], config,
                                        valid_text_iterator)
                if (len(progress.history_errs) == 0 or
                        valid_ce < min(progress.history_errs)):
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter = 0
                    save_non_checkpoint(sess, saver, config.saveto)
                    progress_path = '{0}.progress.json'.format(config.saveto)
                    progress.save_to_json(progress_path)
                else:
                    progress.history_errs.append(valid_ce)
                    progress.bad_counter += 1
                    if progress.bad_counter > config.patience:
                        logging.info('Early Stop!')
                        progress.estop = True
                        break
                if config.valid_script is not None:
                    if config.exponential_smoothing > 0.0:
                        sess.run(fetches=smoothing.swap_ops)
                        score = validate_with_script(sess, beam_search_sampler)
                        sess.run(fetches=smoothing.swap_ops)
                    else:
                        score = validate_with_script(sess, beam_search_sampler)
                    need_to_save = (score is not None and
                                    (len(progress.valid_script_scores) == 0 or
                                     score > max(progress.valid_script_scores)))
                    if score is None:
                        score = 0.0  # ensure a valid value is written
                    progress.valid_script_scores.append(score)
                    if need_to_save:
                        progress.bad_counter = 0
                        save_path = config.saveto + ".best-valid-script"
                        save_non_checkpoint(sess, saver, save_path)
                        write_config_to_json_file(config, save_path)

                        progress_path = '{}.progress.json'.format(save_path)
                        progress.save_to_json(progress_path)

            if config.save_freq and progress.uidx % config.save_freq == 0:
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)

            if config.finish_after and progress.uidx % config.finish_after == 0:
                logging.info("Maximum number of updates reached")
                saver.save(sess, save_path=config.saveto, global_step=progress.uidx)
                write_config_to_json_file(config, "%s-%s" % (config.saveto, progress.uidx))

                progress.estop = True
                progress_path = '{0}-{1}.progress.json'.format(config.saveto, progress.uidx)
                progress.save_to_json(progress_path)
                break
        if progress.estop:
            break