def test(test_data, model_path, model_file, config_file):
    """Evaluate a saved model over a test TSV and emit answers as JSON lines.

    Writes one {'query_id', 'answers'} record per example to
    '<model_file>_out.json'.
    """
    polymath = PolyMath(config_file)
    model = C.load_model(os.path.join(model_path, model_file if model_file else model_name))
    begin_logits, end_logits = model.outputs[0], model.outputs[1]
    loss = C.as_composite(model.outputs[2].owner)

    # Gradient-enabled dummy inputs: the gradient of the best-span score with
    # respect to these becomes a one-hot indicator of the chosen begin/end.
    begin_prediction = C.sequence.input_variable(
        1, sequence_axis=begin_logits.dynamic_axes[1], needs_gradient=True)
    end_prediction = C.sequence.input_variable(
        1, sequence_axis=end_logits.dynamic_axes[1], needs_gradient=True)
    best_span_score = symbolic_best_span(begin_prediction, end_prediction)
    # Running sum of (begin - lagged end) is nonzero exactly inside the span.
    predicted_span = C.layers.Recurrence(C.plus)(
        begin_prediction - C.sequence.past_value(end_prediction))

    batch_size = 32  # in sequences
    misc = {'rawctx': [], 'ctoken': [], 'answer': [], 'uid': []}
    tsv_reader = create_tsv_reader(loss, test_data, polymath, batch_size, 1,
                                   is_test=True, misc=misc)
    results = {}
    with open('{}_out.json'.format(model_file), 'w', encoding='utf-8') as json_output:
        for data in tsv_reader:
            out = model.eval(data, outputs=[begin_logits, end_logits, loss],
                             as_numpy=False)
            g = best_span_score.grad(
                {begin_prediction: out[begin_logits], end_prediction: out[end_logits]},
                wrt=[begin_prediction, end_prediction], as_numpy=False)
            span = predicted_span.eval(
                {begin_prediction: g[begin_prediction], end_prediction: g[end_prediction]})
            batch = zip(misc['rawctx'], misc['ctoken'], misc['answer'], misc['uid'])
            for seq, (raw_text, ctokens, answer, uid) in enumerate(batch):
                where = np.argwhere(span[seq])[:, 0]
                predict_answer = get_answer(raw_text, ctokens,
                                            np.min(where), np.max(where))
                results['query_id'] = int(uid)
                results['answers'] = [predict_answer]
                json.dump(results, json_output)
                json_output.write("\n")
            # The reader appends per batch; reset so indices line up next batch.
            for key in ('rawctx', 'ctoken', 'answer', 'uid'):
                misc[key] = []
def streaming_inference(model_path, model_file, config_file, port="8889", is_test=1):
    """Serve single-question answer extraction over a ZeroMQ REP socket.

    Listens forever on tcp://*:<port>. Each request is a pickled
    (question, context) pair; the reply is a pickled
    (question, predicted_answer) pair.
    """
    polymath = PolyMath(config_file)
    model = C.load_model(
        os.path.join(model_path, model_file if model_file else model_name))
    begin_logits = model.outputs[0]
    end_logits = model.outputs[1]
    loss = C.as_composite(model.outputs[2].owner)
    # Gradient-enabled dummy inputs: grad of the best-span score w.r.t. these
    # yields one-hot begin/end indicators for the selected span.
    begin_prediction = C.sequence.input_variable(
        1, sequence_axis=begin_logits.dynamic_axes[1], needs_gradient=True)
    end_prediction = C.sequence.input_variable(
        1, sequence_axis=end_logits.dynamic_axes[1], needs_gradient=True)
    best_span_score = symbolic_best_span(begin_prediction, end_prediction)
    predicted_span = C.layers.Recurrence(
        C.plus)(begin_prediction - C.sequence.past_value(end_prediction))
    batch_size = 1  # in sequences
    misc = {'rawctx': [], 'ctoken': [], 'answer': [], 'uid': []}

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    # BUG FIX: honor the `port` argument (was hard-coded to "tcp://*:8889").
    socket.bind("tcp://*:" + port)
    while True:
        message = socket.recv()
        # SECURITY: pickle.loads executes arbitrary code if the peer is
        # untrusted; only expose this socket on a trusted network.
        question_str, context_str = pickle.loads(message)
        line = "1102432\tDESCRIPTION\t" + context_str + "\t" + question_str
        # BUG FIX: reset the side-channel lists each request; the reader
        # appends to them, so without a reset index 0 stays pinned to the
        # first request's context (and the lists grow without bound).
        for key in misc:
            misc[key] = []
        data = streaming_create_tsv_reader(loss, line, polymath, batch_size, 1,
                                           is_test=True, misc=misc)
        out = model.eval(data,
                         outputs=[begin_logits, end_logits, loss],
                         as_numpy=False)
        g = best_span_score.grad(
            {
                begin_prediction: out[begin_logits],
                end_prediction: out[end_logits]
            },
            wrt=[begin_prediction, end_prediction],
            as_numpy=False)
        other_input_map = {
            begin_prediction: g[begin_prediction],
            end_prediction: g[end_prediction]
        }
        span = predicted_span.eval(other_input_map)
        raw_text, ctokens = misc['rawctx'], misc['ctoken']
        seq_where = np.argwhere(span[0])[:, 0]
        span_begin = np.min(seq_where)
        span_end = np.max(seq_where)
        predict_answer = get_answer(raw_text[0], ctokens[0], span_begin,
                                    span_end)
        result = (question_str, predict_answer)
        socket.send(pickle.dumps(result))
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False) -> None:
    """Train the PolyMath model with an EMA copy of the weights for validation.

    Reads hyperparameters from `config_file`'s `training_config`, supports
    CTF and TSV input, optional checkpoint restore, CNTK profiling, and
    data-parallel distributed training (1-bit SGD quantization).
    """
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    minibatch_size = training_config['minibatch_size']
    max_epochs = training_config['max_epochs']
    epoch_size = training_config['epoch_size']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    # Exponential moving average of every parameter, updated as a side effect
    # of evaluating `dummy` (the reduce_sum just gives assign() an output).
    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner, num_quantization_bits=1)

    trainer = C.Trainer(z, (loss, None), learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    # NOTE(review): label_ab appears unused below — possibly a leftover; confirm.
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {'best_val_err': 100, 'best_since': 0, 'val_since': 0}

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        # After restore, always re-evaluate to re-establish the best-so-far error.
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        """Summarize the epoch; every `val_interval` epochs swap in EMA weights,
        validate, checkpoint on improvement. Returns False to stop training."""
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            # Save the live weights, validate with the EMA weights instead.
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model, polymath)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                trainer.save_checkpoint(model_file)
                # Restore the live (non-EMA) weights before resuming training.
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                # NOTE(review): on a non-improving validation the parameters are
                # NOT restored from `temp`, so training continues from the EMA
                # weights — confirm this is intentional.
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)
        for epoch in range(max_epochs):
            num_seq = 0
            with tqdm(total=epoch_size, ncols=32, smoothing=0.1) as progress_bar:
                while True:
                    # Before `distributed_after` samples, every worker reads the
                    # full minibatch; afterwards data is partitioned by rank.
                    if trainer.total_number_of_samples_seen >= training_config[
                            'distributed_after']:
                        data = mb_source.next_minibatch(
                            minibatch_size * C.Communicator.num_workers(),
                            input_map=input_map,
                            num_data_partitions=C.Communicator.num_workers(),
                            partition_index=C.Communicator.rank())
                    else:
                        data = mb_source.next_minibatch(minibatch_size,
                                                        input_map=input_map)
                    trainer.train_minibatch(data)
                    num_seq += trainer.previous_minibatch_sample_count
                    dummy.eval()  # side effect: advance the EMA weights
                    if num_seq >= epoch_size:
                        break
                    else:
                        progress_bar.update(
                            trainer.previous_minibatch_sample_count)
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")
        minibatch_seqs = training_config[
            'minibatch_seqs']  # number of sequences
        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                # Round-robin minibatches across workers by rank.
                if (minibatch_count %
                        C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()  # side effect: advance the EMA weights
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False):
    """Train the PolyMath model (multi-GPU variant with TensorBoard logging).

    Reads hyperparameters from `config_file`'s `training_config`; supports
    explicit GPU assignment per rank, CTF and TSV input, checkpoint restore,
    profiling, and numbered checkpoint snapshots on every validation
    improvement.
    """
    training_config = importlib.import_module(config_file).training_config
    # config for using multi GPUs: map each worker rank onto a device id
    if training_config['multi_gpu']:
        gpu_pad = training_config['gpu_pad']
        gpu_cnt = training_config['gpu_cnt']
        my_rank = C.Communicator.rank()
        my_gpu_id = (my_rank + gpu_pad) % gpu_cnt
        print("rank = " + str(my_rank) + ", using gpu " + str(my_gpu_id) +
              " of " + str(gpu_cnt))
        C.try_set_default_device(C.gpu(my_gpu_id))
    else:
        C.try_set_default_device(C.gpu(0))

    # Text log and TensorBoard event files share the same directory.
    normal_log = os.path.join(data_path, training_config['logdir'], log_file)
    tensorboard_logdir = os.path.join(data_path, training_config['logdir'],
                                      log_file)

    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=normal_log,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]
    tensorboard_writer = C.logging.TensorBoardProgressWriter(
        freq=10,
        log_dir=tensorboard_logdir,
        rank=C.Communicator.rank(),
        model=z)
    progress_writers.append(tensorboard_writer)

    lr = C.learning_parameter_schedule(training_config['lr'],
                                       minibatch_size=None,
                                       epoch_size=None)

    # Shadow copy of every parameter, refreshed by evaluating `dummy`
    # (plain copy here, not a decayed moving average as in other variants).
    ema = {}
    dummies_info = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype,
                           name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, p)))
        dummies_info[dummies[-1].output] = (p.name, p.shape)
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)
    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)
    trainer = C.Trainer(z, (loss, None), learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()
    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    # NOTE(review): label_ab appears unused below — possibly a leftover; confirm.
    label_ab = argument_by_name(loss, 'ab')
    epoch_stat = {
        'best_val_err': 100,
        'best_since': 0,
        'val_since': 0,
        'record_num': 0
    }

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        # after restore always re-evaluate
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model,
            polymath, config_file)

    def post_epoch_work(epoch_stat):
        """Summarize the epoch; every `val_interval` epochs swap in the shadow
        weights, validate, checkpoint on improvement. False stops training."""
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            # Save the live weights, validate with the shadow copy instead.
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model,
                polymath, config_file)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                # Debug aid: record directory listings around checkpointing.
                os.system("ls -la >> log.log")
                os.system("ls -la ./Models >> log.log")
                # Retry checkpointing up to 100 times (shared-FS flakiness).
                save_flag = True
                fail_cnt = 0
                while save_flag:
                    if fail_cnt > 100:
                        print("ERROR: failed to save models")
                        break
                    try:
                        trainer.save_checkpoint(model_file)
                        epoch_stat['record_num'] += 1
                        record_file = os.path.join(
                            model_path,
                            str(epoch_stat['record_num']) + '-' + model_name)
                        trainer.save_checkpoint(record_file)
                        save_flag = False
                    # BUG FIX: was a bare `except:` which also swallowed
                    # KeyboardInterrupt/SystemExit, making the retry loop
                    # un-interruptible.
                    except Exception:
                        fail_cnt = fail_cnt + 1
                # Restore the live weights before resuming training.
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file,
                                                 polymath)
        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']
        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                # Before `distributed_after` samples, every worker reads the
                # full minibatch; afterwards data is partitioned by rank.
                if trainer.total_number_of_samples_seen >= training_config[
                        'distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size,
                                                    input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                # BUG FIX: refresh the shadow weights. This call was missing
                # (only a commented print_para_info remained), so `ema` stayed
                # at its initial zeros and validation/checkpointing in
                # post_epoch_work ran on an all-zero model. The TSV branch
                # below already calls dummy.eval() per minibatch.
                dummy.eval()
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")
        minibatch_seqs = training_config[
            'minibatch_seqs']  # number of sequences
        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                # Round-robin minibatches across workers by rank.
                if (minibatch_count %
                        C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
def streaming_inference(line, model_path, model_file, config_file, port="8889", is_test=1):
    """Run answer extraction once for a single pre-formatted TSV `line`.

    The result is printed; nothing is returned. `port`/`is_test` are kept for
    signature compatibility with the socket-serving variant.
    """
    polymath = PolyMath(config_file)
    model = C.load_model(
        os.path.join(model_path, model_file if model_file else model_name))
    begin_logits = model.outputs[0]
    end_logits = model.outputs[1]
    loss = C.as_composite(model.outputs[2].owner)
    # Gradient-enabled dummy inputs: grad of the best-span score w.r.t. these
    # yields one-hot begin/end indicators for the selected span.
    begin_prediction = C.sequence.input_variable(
        1, sequence_axis=begin_logits.dynamic_axes[1], needs_gradient=True)
    end_prediction = C.sequence.input_variable(
        1, sequence_axis=end_logits.dynamic_axes[1], needs_gradient=True)
    best_span_score = symbolic_best_span(begin_prediction, end_prediction)
    predicted_span = C.layers.Recurrence(
        C.plus)(begin_prediction - C.sequence.past_value(end_prediction))
    batch_size = 1  # in sequences
    misc = {'rawctx': [], 'ctoken': [], 'answer': [], 'uid': []}

    # Cleanup: the original wrapped this body in `while Flag:` / `if True:`
    # scaffolding that always executed exactly once; straight-line code is
    # behaviorally identical.
    data = streaming_create_tsv_reader(loss, line, polymath, batch_size, 1,
                                       is_test=True, misc=misc)
    out = model.eval(data,
                     outputs=[begin_logits, end_logits, loss],
                     as_numpy=False)
    g = best_span_score.grad(
        {
            begin_prediction: out[begin_logits],
            end_prediction: out[end_logits]
        },
        wrt=[begin_prediction, end_prediction],
        as_numpy=False)
    other_input_map = {
        begin_prediction: g[begin_prediction],
        end_prediction: g[end_prediction]
    }
    span = predicted_span.eval(other_input_map)
    print("just before for {}".format(misc['ctoken']))
    seq, raw_text, ctokens, answer, uid = 0, misc['rawctx'], misc[
        'ctoken'], misc['answer'], misc['uid']
    print("just AFTER for {}".format(ctokens))
    seq_where = np.argwhere(span[seq])[:, 0]
    span_begin = np.min(seq_where)
    span_end = np.max(seq_where)
    print("before predict")
    predict_answer = get_answer(raw_text[0], ctokens[0], span_begin, span_end)
    result = predict_answer
    print(result)
def test(i2w, test_data, model_path, model_file, config_file):
    """Evaluate a saved seq2seq-style model over a test TSV and emit answers.

    Unlike the span-only variant, this one extracts a span from the context
    and then runs it through `format_output_sequences` with the vocabulary
    map `i2w` before writing {'query_id', 'answers'} JSON lines to
    '<model_file>_out.json'.
    """
    #C.try_set_default_device(C.cpu())
    polymath = PolyMath(config_file)
    print(test_data, model_path, model_file, model_name)
    print(os.path.join(model_path, model_file))
    model = C.Function.load(
        os.path.join(model_path, model_file if model_file else model_name))
    print(model)
    # NOTE(review): output-index mapping assumed from this call site —
    # 1: decoded output, 2/3: start/end logits, 4: context; confirm against
    # the model definition.
    output = model.outputs[1]
    # loss = model.outputs[5]
    start_logits = model.outputs[2]
    end_logits = model.outputs[3]
    context = model.outputs[4]
    # loss = model.outputs[5]
    root = C.as_composite(output.owner)
    # Gradient-enabled dummy inputs: grad of the best-span score w.r.t. these
    # yields one-hot begin/end indicators for the selected span.
    begin_prediction = C.sequence.input_variable(
        1, sequence_axis=start_logits.dynamic_axes[1], needs_gradient=True)
    end_prediction = C.sequence.input_variable(
        1, sequence_axis=end_logits.dynamic_axes[1], needs_gradient=True)
    # Running sum of (begin - lagged end) marks positions inside the span.
    predicted_span = C.layers.Recurrence(
        C.plus)(begin_prediction - C.sequence.past_value(end_prediction))
    best_span_score = symbolic_best_span(begin_prediction, end_prediction)
    batch_size = 1  # in sequences
    misc = {'rawctx': [], 'ctoken': [], 'answer': [], 'uid': []}
    tsv_reader = create_tsv_reader(root,
                                   test_data,
                                   polymath,
                                   batch_size,
                                   1,
                                   is_test=True,
                                   misc=misc)
    results = {}
    with open('{}_out.json'.format(model_file), 'w',
              encoding='utf-8') as json_output:
        for data in tsv_reader:
            out = model.eval(
                data,
                outputs=[output, start_logits, end_logits, context],
                as_numpy=False)
            g = best_span_score.grad(
                {
                    begin_prediction: out[start_logits],
                    end_prediction: out[end_logits]
                },
                wrt=[begin_prediction, end_prediction],
                as_numpy=False)
            other_input_map = {
                begin_prediction: g[begin_prediction],
                end_prediction: g[end_prediction]
            }
            span = predicted_span.eval((other_input_map))
            for seq, (raw_text, ctokens, answer, uid) in enumerate(
                    zip(misc['rawctx'], misc['ctoken'], misc['answer'],
                        misc['uid'])):
                # g = best_span_score.grad({begin_prediction:out[start_logits], end_prediction:out[end_logits]}, wrt=[begin_prediction,end_prediction], as_numpy=False)
                # other_input_map = {begin_prediction: g[begin_prediction], end_prediction: g[end_prediction]}
                # span = predicted_span.eval((other_input_map))
                seq_where = np.argwhere(span[seq])[:, 0]
                span_begin = np.min(seq_where)
                span_end = np.max(seq_where)
                predict_answer = get_answer(raw_text, ctokens, span_begin,
                                            span_end)
                # span_out = np.asarray(span).reshape(-1).tolist()
                # context_o = np.asarray(out[context]).reshape(-1).tolist()
                # predict_answer = []
                # for i in range(len(span_out)):
                #     if(span_out[i]==1):
                #         predict_answer.append(context_o[i])
                print(predict_answer)
                # Convert the decoded output plus the extracted span into the
                # final answer string via the vocabulary map.
                final_answer = format_output_sequences(
                    np.asarray(out[output].as_sequences()).reshape(-1),
                    predict_answer, i2w, polymath)
                results['query_id'] = int(uid)
                results['answers'] = [final_answer]
                print(results)
                json.dump(results, json_output)
                json_output.write("\n")
            # The reader appends per batch; reset so indices line up next batch.
            misc['rawctx'] = []
            misc['ctoken'] = []
            misc['answer'] = []
            misc['uid'] = []
def train(i2w, data_path, model_path, log_file, config_file, restore=True, profiling=False, gen_heartbeat=False) -> None:
    """Train the PolyMath model with fsadagrad and a pointer-importance decay.

    `i2w` is forwarded to `validate_model`. Unlike the other variants, this
    one saves the model unconditionally after every validation and decays
    `polymath.pointer_importance` by 0.9 each epoch down to 10% of its
    initial value.
    """
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config
    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']
    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  metric_is_pct=False,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]
    lr = C.learning_parameter_schedule(training_config['lr'],
                                       minibatch_size=None,
                                       epoch_size=None)
    # Exponential moving average of every parameter, updated as a side effect
    # of evaluating `dummy` (the reduce_sum just gives assign() an output).
    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype,
                           name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p,
                                             0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)
    # learner = C.adadelta(z.parameters, lr)
    learner = C.fsadagrad(
        z.parameters,
        #apply the learning rate as if it is a minibatch of size 1
        lr,
        momentum=C.momentum_schedule(
            0.9366416204111472,
            minibatch_size=training_config['minibatch_size']),
        gradient_clipping_threshold_per_sample=2.3,
        gradient_clipping_with_truncation=True)
    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)
    trainer = C.Trainer(z, loss, learner, progress_writers)
    if profiling:
        C.debugging.start_profiler(sync_gpu=True)
    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()
    model_file = os.path.join(model_path, model_name)
    model = C.combine(z.outputs + loss.outputs)  #this is for validation only
    epoch_stat = {'best_val_err': 1000, 'best_since': 0, 'val_since': 0}
    print(restore, os.path.isfile(model_file))
    # if restore and os.path.isfile(model_file):
    if restore and os.path.isfile(model_file):
        z.restore(model_file)
        #after restore always re-evaluate
        #TODO replace with rougel with external script(possibly)
        #epoch_stat['best_val_err'] = validate_model(i2w, os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        """Summarize the epoch; every `val_interval` epochs validate with EMA
        weights and save unconditionally. Returns False to stop training."""
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            # Snapshot of the live weights; swap in the EMA copy to validate.
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            #TODO replace with rougel with external script(possibly)
            val_err = validate_model(
                i2w, os.path.join(data_path, training_config['val_data']),
                model, polymath)
            # NOTE(review): the best-err/early-stop logic below is disabled;
            # the model is saved every validation and `temp` is never restored,
            # so training continues from the EMA weights — confirm intended.
            #if epoch_stat['best_val_err'] > val_err:
            #    epoch_stat['best_val_err'] = val_err
            #    epoch_stat['best_since'] = 0
            #    trainer.save_checkpoint(model_file)
            #    for p in trainer.model.parameters:
            #        p.value = temp[p.uid]
            #else:
            #    epoch_stat['best_since'] += 1
            #    if epoch_stat['best_since'] > training_config['stop_after']:
            #        return False
            z.save(model_file)
            epoch_stat['best_since'] += 1
            if epoch_stat['best_since'] > training_config['stop_after']:
                return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    init_pointer_importance = polymath.pointer_importance
    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file,
                                                 polymath)
        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']
        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                # Before `distributed_after` samples, every worker reads the
                # full minibatch; afterwards data is partitioned by rank.
                if trainer.total_number_of_samples_seen >= training_config[
                        'distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size,
                                                    input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                dummy.eval()  # side effect: advance the EMA weights
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
            # Anneal pointer importance by 10% per epoch, floored at 10% of
            # its initial value.
            print('Before Pointer_importance:', polymath.pointer_importance)
            if polymath.pointer_importance > 0.1 * init_pointer_importance:
                polymath.pointer_importance = polymath.pointer_importance * 0.9
                print('Pointer_importance:', polymath.pointer_importance)
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")
        minibatch_seqs = training_config[
            'minibatch_seqs']  # number of sequences
        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                # Round-robin minibatches across workers by rank.
                if (minibatch_count %
                        C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break
    if profiling:
        C.debugging.stop_profiler()
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False):
    """Train the PolyMath model with EMA validation and TensorBoard logging.

    Reads hyperparameters from `config_file`'s `training_config`; supports
    CTF and TSV input, optional checkpoint restore, CNTK profiling, and
    data-parallel distributed training.
    """
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [C.logging.ProgressPrinter(
        num_epochs=max_epochs,
        freq=log_freq,
        tag='Training',
        log_to_file=log_file,
        rank=C.Communicator.rank(),
        gen_heartbeat=gen_heartbeat)]

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    # Exponential moving average of every parameter, updated as a side effect
    # of evaluating `dummy` (the reduce_sum just gives assign() an output).
    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    tensorboard_writer = TensorBoardProgressWriter(freq=10, log_dir='log', model=z)
    # BUG FIX: progress_writers (including the file-logging ProgressPrinter
    # configured above) was built but never passed to the Trainer — only the
    # TensorBoard writer was. Pass both so text logging works as configured.
    trainer = C.Trainer(z, (loss, None), learner,
                        progress_writers + [tensorboard_writer])

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    # NOTE(review): label_ab appears unused below — possibly a leftover; confirm.
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {'best_val_err': 100, 'best_since': 0, 'val_since': 0}

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        # after restore always re-evaluate
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        """Summarize the epoch; every `val_interval` epochs swap in EMA weights,
        validate, checkpoint on improvement. Returns False to stop training."""
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            # Save the live weights, validate with the EMA weights instead.
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model, polymath)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                trainer.save_checkpoint(model_file)
                # Restore the live (non-EMA) weights before resuming training.
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                # NOTE(review): on a non-improving validation the parameters
                # are NOT restored from `temp`, so training continues from the
                # EMA weights — kept as-is to match the sibling variants.
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)
        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']
        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                # Before `distributed_after` samples, every worker reads the
                # full minibatch; afterwards data is partitioned by rank.
                if trainer.total_number_of_samples_seen >= training_config['distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size, input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                dummy.eval()  # side effect: advance the EMA weights
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")
        minibatch_seqs = training_config['minibatch_seqs']  # number of sequences
        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                # Round-robin minibatches across workers by rank.
                if (minibatch_count % C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()