def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    lr_per_mb = [1.0]*30 + [0.1]*30 + [0.01]*20 + [0.001]
    l2_reg_weight = 0.0001

    # adjust LR with minibatch size
    if minibatch_size != 256:
        for i in range(0, len(lr_per_mb)):
            lr_per_mb[i] *= minibatch_size / 256

    # Set learning parameters
    lr_schedule = learning_rate_schedule(lr_per_mb, epoch_size=epoch_size, unit=UnitType.minibatch)
    mm_schedule = momentum_schedule(0.9)
    local_learner = nesterov(network['output'].parameters, lr_schedule, mm_schedule,
                             l2_regularization_weight=l2_reg_weight)

    # learner object
    if block_size != None and num_quantization_bits != 32:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    if block_size != None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner,
                                                    num_quantization_bits=num_quantization_bits,
                                                    distributed_after=warm_up)

    return Trainer(network['output'], (network['ce'], network['errs']), learner, progress_printer)
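# --- Minimal usage sketch (NOT part of the original script) ---
# create_trainer above only builds the distributed learner and the Trainer.
# The sketch below shows how such a trainer is typically driven so that each
# MPI worker reads its own partition of every minibatch. The names
# `train_loop`, `train_source`, and `input_map` are assumptions for
# illustration; only the CNTK calls themselves are real API.
from cntk.train.distributed import Communicator

def train_loop(trainer, train_source, input_map, minibatch_size, epoch_size):
    samples_seen = 0
    while samples_seen < epoch_size:
        # Each worker fetches only its slice of the global minibatch.
        data = train_source.next_minibatch(
            minibatch_size,
            input_map=input_map,
            num_data_partitions=Communicator.num_workers(),
            partition_index=Communicator.rank())
        trainer.train_minibatch(data)  # gradients are aggregated across workers here
        samples_seen += trainer.previous_minibatch_sample_count

# Before the process exits, every worker must shut down MPI exactly once:
# Communicator.finalize()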
def create_trainer(self):
    try:
        p = self.output.parameters
        # Three of four parameters are learned by block_momentum_distributed_learner.
        bmd_learner = cntk.block_momentum_distributed_learner(
            cntk.momentum_sgd([p[0], p[1], p[2]],
                              cntk.learning_parameter_schedule(0.0001),
                              cntk.momentum_as_time_constant_schedule(1000)),
            block_size=1000,
            block_learning_rate=0.01,
            block_momentum_as_time_constant=1000)

        # New API to mark which learner is to use for metric aggregation.
        bmd_learner.set_as_metric_aggregator()

        # The last parameter is learned by the data_parallel_distributed_learner.
        momentum_schedule = cntk.momentum_schedule_per_sample(0.9990913221888589)
        lr_per_sample = cntk.learning_parameter_schedule_per_sample(0.007)
        dpd_learner = cntk.data_parallel_distributed_learner(
            cntk.momentum_sgd([p[3]], lr_per_sample, momentum_schedule, True))

        comm_rank = cntk.distributed.Communicator.rank()
        self.trainer = cntk.Trainer(
            self.output, (self.ce, self.err), [bmd_learner, dpd_learner],
            [cntk.logging.ProgressPrinter(freq=progress_freq, tag="Training", rank=comm_rank)])
    except RuntimeError:
        self.trainer = None
    return
def create_distributed_learner(self, mode, config):
    local_learner = C.sgd(self.z.parameters, C.learning_parameter_schedule_per_sample(0.01))
    try:
        if mode == 'data_parallel':
            if config is None:
                config = DataParallelConfig(num_quantization_bits=32, distributed_after=0)
            learner = C.data_parallel_distributed_learner(
                local_learner,
                num_quantization_bits=config.num_quantization_bits,
                distributed_after=config.distributed_after)
        elif mode == 'block_momentum':
            if config is None:
                # the default config to match data parallel SGD
                config = BlockMomentumConfig(block_momentum_as_time_constant=0,
                                             block_learning_rate=1,
                                             block_size=NUM_WORKERS,
                                             distributed_after=0)
            learner = C.block_momentum_distributed_learner(
                local_learner,
                block_momentum_as_time_constant=config.block_momentum_as_time_constant,
                block_learning_rate=config.block_learning_rate,
                block_size=config.block_size,
                distributed_after=config.distributed_after)
        else:
            learner = local_learner
    except RuntimeError:
        learner = None
    return learner
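# --- Hypothetical call sites for the helper above (NOT part of the original
# test) --- It assumes the DataParallelConfig namedtuple used by the snippet
# and an instance `test` of the class that defines create_distributed_learner.
dp_learner = test.create_distributed_learner(
    mode='data_parallel',
    config=DataParallelConfig(num_quantization_bits=1,  # 1-bit SGD quantization
                              distributed_after=64))    # warm start on 64 samples
bm_learner = test.create_distributed_learner(mode='block_momentum', config=None)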
def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    if network['name'] == 'resnet20':
        lr_per_mb = [1.0]*80 + [0.1]*40 + [0.01]
    elif network['name'] == 'resnet110':
        lr_per_mb = [0.1]*1 + [1.0]*80 + [0.1]*40 + [0.01]
    else:
        raise RuntimeError("Unknown model name!")

    momentum_time_constant = -minibatch_size/np.log(0.9)
    l2_reg_weight = 0.0001

    # Set learning parameters
    lr_per_sample = [lr/minibatch_size for lr in lr_per_mb]
    lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size, unit=UnitType.sample)
    mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant)

    # learner object
    if block_size != None and num_quantization_bits != 32:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    local_learner = momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule,
                                 l2_regularization_weight=l2_reg_weight)

    if block_size != None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner,
                                                    num_quantization_bits=num_quantization_bits,
                                                    distributed_after=warm_up)

    return Trainer(network['output'], (network['ce'], network['pe']), learner, progress_printer)
def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    lr_per_mb = [0.1]  # [1.0]*30 + [0.1]*30 + [0.01]*20 + [0.001]
    l2_reg_weight = 0.0001

    # adjust LR with minibatch size
    #if minibatch_size != 256:
    #    for i in range(0, len(lr_per_mb)):
    #        lr_per_mb[i] *= minibatch_size / 256

    # Set learning parameters
    lr_schedule = learning_rate_schedule(lr_per_mb, epoch_size=epoch_size, unit=UnitType.minibatch)
    mm_schedule = momentum_schedule(0.9)
    local_learner = nesterov(network['output'].parameters, lr_schedule, mm_schedule,
                             l2_regularization_weight=l2_reg_weight)

    # learner object
    if block_size != None and num_quantization_bits != 32:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    if block_size != None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner,
                                                    num_quantization_bits=num_quantization_bits,
                                                    distributed_after=warm_up)

    return Trainer(network['output'], (network['ce'], network['errs']), learner, progress_printer)
def __init__(self):
    self.input_dim = 40000
    self.embed_dim = 100
    self.batch_size = 20
    i = C.input_variable((self.input_dim,), is_sparse=True)
    self.p = C.parameter(shape=(self.input_dim, self.embed_dim), init=1)
    o = C.times(i, self.p)
    z = C.reduce_sum(o)
    learner = C.data_parallel_distributed_learner(
        C.sgd(z.parameters,
              C.learning_rate_schedule(0.01, unit=C.learners.UnitType.sample)))
    self.trainer = C.Trainer(z, (z, None), learner, [])
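# --- Hedged feeding sketch (NOT part of the original class) --- It assumes the
# constructor above also kept the sparse input variable on the instance
# (self.i = i), which the original does not do, and feeds one random SciPy CSR
# batch through the trainer.
import numpy as np
import scipy.sparse as sp

def train_one_step(model):
    # one batch of 20 sparse rows over the 40000-dim input
    batch = sp.rand(model.batch_size, model.input_dim,
                    density=0.001, format='csr', dtype=np.float32)
    model.trainer.train_minibatch({model.i: batch})
    print(model.trainer.previous_minibatch_loss_average)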
def create_trainer(self):
    try:
        lr_per_sample = cntk.learning_parameter_schedule_per_sample(0.007)
        p = self.output.parameters
        # Three of four parameters are learned by the first data_parallel_distributed_learner.
        learner1 = cntk.data_parallel_distributed_learner(
            cntk.sgd([p[0], p[1], p[2]], lr_per_sample))

        # New API to mark which learner is to use for metric aggregation.
        learner1.set_as_metric_aggregator()

        # The last parameter is learned by another data_parallel_distributed_learner.
        learner2 = cntk.data_parallel_distributed_learner(
            cntk.sgd([p[3]], lr_per_sample))

        comm_rank = cntk.distributed.Communicator.rank()
        self.trainer = cntk.Trainer(
            self.output, (self.ce, self.err), [learner1, learner2],
            [cntk.logging.ProgressPrinter(freq=progress_freq, tag="Training", rank=comm_rank)])
    except RuntimeError:
        self.trainer = None
    return
def create_trainer(self):
    try:
        lr_per_sample = cntk.learning_parameter_schedule_per_sample(0.007)
        learner = cntk.data_parallel_distributed_learner(
            cntk.sgd(self.output.parameters, lr_per_sample))
        comm_rank = cntk.distributed.Communicator.rank()
        self.trainer = cntk.Trainer(
            self.output, (self.ce, self.err), [learner],
            [cntk.logging.ProgressPrinter(freq=progress_freq, tag="Training", rank=comm_rank)])
    except RuntimeError:
        self.trainer = None
    return
def run(self, iterator, mode='train'):
    report = dict()
    input_var = C.ops.input_variable(np.prod(iterator.image_shape), np.float32)
    label_var = C.ops.input_variable(iterator.batch_size, np.float32)
    model = self.model(input_var)
    ce = C.losses.cross_entropy_with_softmax(model, label_var)
    pe = C.metrics.classification_error(model, label_var)
    z = model
    learner = C.learners.momentum_sgd(z.parameters, self.lr_schedule, self.m_schedule)
    if self.is_parallel:
        distributed_learner = C.data_parallel_distributed_learner(learner=learner,
                                                                  distributed_after=0)
    progress_printer = C.logging.ProgressPrinter(tag='Training',
                                                 num_epochs=iterator.niteration)
    if self.is_parallel:
        trainer = C.Trainer(z, (ce, pe), distributed_learner, progress_printer)
    else:
        trainer = C.Trainer(z, (ce, pe), learner, progress_printer)

    for idx, (x, t) in enumerate(iterator):
        total_s = time.perf_counter()
        trainer.train_minibatch({input_var: x, label_var: t})
        forward_s = time.perf_counter()
        forward_e = time.perf_counter()
        backward_s = time.perf_counter()
        backward_e = time.perf_counter()
        total_e = time.perf_counter()
        report[idx] = dict(
            forward=forward_e - forward_s,
            backward=backward_e - backward_s,
            total=total_e - total_s
        )
    return report
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False):
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    minibatch_size = training_config['minibatch_size']
    max_epochs = training_config['max_epochs']
    epoch_size = training_config['epoch_size']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner, num_quantization_bits=1)

    trainer = C.Trainer(z, (loss, None), learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {'best_val_err': 100, 'best_since': 0, 'val_since': 0}

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model, polymath)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                trainer.save_checkpoint(model_file)
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)
        for epoch in range(max_epochs):
            num_seq = 0
            with tqdm(total=epoch_size, ncols=32, smoothing=0.1) as progress_bar:
                while True:
                    if trainer.total_number_of_samples_seen >= training_config['distributed_after']:
                        data = mb_source.next_minibatch(
                            minibatch_size * C.Communicator.num_workers(),
                            input_map=input_map,
                            num_data_partitions=C.Communicator.num_workers(),
                            partition_index=C.Communicator.rank())
                    else:
                        data = mb_source.next_minibatch(minibatch_size, input_map=input_map)
                    trainer.train_minibatch(data)
                    num_seq += trainer.previous_minibatch_sample_count
                    dummy.eval()
                    if num_seq >= epoch_size:
                        break
                    else:
                        progress_bar.update(trainer.previous_minibatch_sample_count)
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config['minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs, C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count % C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False):
    training_config = importlib.import_module(config_file).training_config

    # config for using multiple GPUs
    if training_config['multi_gpu']:
        gpu_pad = training_config['gpu_pad']
        gpu_cnt = training_config['gpu_cnt']
        my_rank = C.Communicator.rank()
        my_gpu_id = (my_rank + gpu_pad) % gpu_cnt
        print("rank = " + str(my_rank) + ", using gpu " + str(my_gpu_id) + " of " + str(gpu_cnt))
        C.try_set_default_device(C.gpu(my_gpu_id))
    else:
        C.try_set_default_device(C.gpu(0))

    # outputs while training
    normal_log = os.path.join(data_path, training_config['logdir'], log_file)
    # tensorboard files' dir
    tensorboard_logdir = os.path.join(data_path, training_config['logdir'], log_file)

    polymath = PolyMath(config_file)
    z, loss = polymath.model()

    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=normal_log,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    # add tensorboard writer for visualization
    tensorboard_writer = C.logging.TensorBoardProgressWriter(freq=10,
                                                             log_dir=tensorboard_logdir,
                                                             rank=C.Communicator.rank(),
                                                             model=z)
    progress_writers.append(tensorboard_writer)

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    ema = {}
    dummies_info = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, p)))
        dummies_info[dummies[-1].output] = (p.name, p.shape)
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    trainer = C.Trainer(z, (loss, None), learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {
        'best_val_err': 100,
        'best_since': 0,
        'val_since': 0,
        'record_num': 0
    }

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        # after restore always re-evaluate
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model, polymath, config_file)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model, polymath, config_file)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                os.system("ls -la >> log.log")
                os.system("ls -la ./Models >> log.log")
                save_flag = True
                fail_cnt = 0
                while save_flag:
                    if fail_cnt > 100:
                        print("ERROR: failed to save models")
                        break
                    try:
                        trainer.save_checkpoint(model_file)
                        epoch_stat['record_num'] += 1
                        record_file = os.path.join(model_path,
                                                   str(epoch_stat['record_num']) + '-' + model_name)
                        trainer.save_checkpoint(record_file)
                        save_flag = False
                    except Exception:
                        fail_cnt = fail_cnt + 1
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)
        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']
        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                if trainer.total_number_of_samples_seen >= training_config['distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size, input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                # print_para_info(dummy, dummies_info)
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config['minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs, C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count % C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
def train(i2w, data_path, model_path, log_file, config_file, restore=True, profiling=False, gen_heartbeat=False):
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  metric_is_pct=False,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    # learner = C.adadelta(z.parameters, lr)
    learner = C.fsadagrad(
        z.parameters,
        # apply the learning rate as if it is a minibatch of size 1
        lr,
        momentum=C.momentum_schedule(0.9366416204111472,
                                     minibatch_size=training_config['minibatch_size']),
        gradient_clipping_threshold_per_sample=2.3,
        gradient_clipping_with_truncation=True)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    trainer = C.Trainer(z, loss, learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(z.outputs + loss.outputs)  # this is for validation only

    epoch_stat = {'best_val_err': 1000, 'best_since': 0, 'val_since': 0}

    print(restore, os.path.isfile(model_file))
    # if restore and os.path.isfile(model_file):
    if restore and os.path.isfile(model_file):
        z.restore(model_file)
        # after restore always re-evaluate
        # TODO replace with rougel with external script (possibly)
        # epoch_stat['best_val_err'] = validate_model(i2w, os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            # TODO replace with rougel with external script (possibly)
            val_err = validate_model(i2w, os.path.join(data_path, training_config['val_data']),
                                     model, polymath)
            # if epoch_stat['best_val_err'] > val_err:
            #     epoch_stat['best_val_err'] = val_err
            #     epoch_stat['best_since'] = 0
            #     trainer.save_checkpoint(model_file)
            #     for p in trainer.model.parameters:
            #         p.value = temp[p.uid]
            # else:
            #     epoch_stat['best_since'] += 1
            #     if epoch_stat['best_since'] > training_config['stop_after']:
            #         return False
            z.save(model_file)
            epoch_stat['best_since'] += 1
            if epoch_stat['best_since'] > training_config['stop_after']:
                return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    init_pointer_importance = polymath.pointer_importance

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)
        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']
        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                if trainer.total_number_of_samples_seen >= training_config['distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size, input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                dummy.eval()
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
            print('Before Pointer_importance:', polymath.pointer_importance)
            if polymath.pointer_importance > 0.1 * init_pointer_importance:
                polymath.pointer_importance = polymath.pointer_importance * 0.9
                print('Pointer_importance:', polymath.pointer_importance)
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config['minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs, C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count % C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False):
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    tensorboard_writer = TensorBoardProgressWriter(freq=10, log_dir='log', model=z)
    trainer = C.Trainer(z, (loss, None), learner, tensorboard_writer)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {'best_val_err': 100, 'best_since': 0, 'val_since': 0}

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        # after restore always re-evaluate
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1
        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model, polymath)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                trainer.save_checkpoint(model_file)
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False
        if profiling:
            C.debugging.enable_profiler()
        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)
        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']
        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                if trainer.total_number_of_samples_seen >= training_config['distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size, input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                dummy.eval()
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config['minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs, C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count % C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
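# --- Launch and shutdown sketch (NOT part of the original scripts) --- All of
# the distributed train() variants above are meant to run with one process per
# worker, e.g.:
#
#     mpiexec -n 4 python train_script.py
#
# The argument names below (data_path, model_path, log_file, config_file) are
# placeholders for whichever variant is actually used; the only hard
# requirement is that every worker finalizes the MPI communicator exactly once
# before exiting, even when training raises.
if __name__ == '__main__':
    try:
        train(data_path, model_path, log_file, config_file)
    finally:
        C.Communicator.finalize()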