def train(Dataset, model, loss_fn, evaluate_fn=None, inference_fn=None,
          eval_fn=None, write_valid=True, valid_names=None, infer_names=None,
          infer_debug_names=None, valid_write_fn=None, infer_write_fn=None,
          valid_suffix='.valid', infer_suffix='.infer', write_streaming=False,
          sep=','):
    """Train ``model`` with batches pulled from a graph-mode TF iterator.

    Training files come from ``FLAGS.train_input`` (sorted); when
    ``FLAGS.fold`` is set, files ending in ``'<fold>.record'`` are held out.
    The latest checkpoint under ``<FLAGS.model_dir>/ckpt`` is restored before
    training. The step loop uses torch APIs (``model.train()``,
    ``loss.backward()``, ``optimizer.zero_grad()``), so only ``FLAGS.torch``
    runs can actually train here.

    NOTE(review): later ``def train`` statements in this module rebind the
    name, so this version is not reachable as ``<module>.train``.

    Args:
      Dataset: callable; ``Dataset('train')`` must provide
        ``num_examples_per_epoch`` and ``make_batch``.
      model: model to train.
      loss_fn: called as ``loss_fn(model, x, y)``; its result must support
        ``.backward()`` and ``.item()``.
      evaluate_fn, inference_fn, eval_fn, write_valid, valid_names,
      infer_names, infer_debug_names, valid_write_fn, infer_write_fn,
      valid_suffix, infer_suffix, write_streaming, sep: accepted for
        signature compatibility with the other ``train`` variants; not used
        by this implementation.
    """
    import itertools

    if FLAGS.torch and torch.cuda.is_available():
        model.cuda()

    inputs = gezi.list_files(FLAGS.train_input)
    inputs.sort()

    batch_size = FLAGS.batch_size

    num_gpus = melt.num_gpus()
    assert num_gpus <= 1, 'Eager mode train currently not support for num gpus > 1'

    if FLAGS.fold is not None:
        inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)]

    logging.info('inputs', inputs)

    dataset = Dataset('train')
    num_examples = dataset.num_examples_per_epoch('train')

    if num_examples:
        if FLAGS.fold is not None:
            # Assumes the held-out fold is one of len(inputs) + 1 equally
            # sized parts of the data.
            num_examples = int(num_examples * (len(inputs) / (len(inputs) + 1)))
        num_steps_per_epoch = -(-num_examples // batch_size)  # ceil division
    else:
        num_steps_per_epoch = None

    summary = tf.contrib.summary
    # NOTE(review): these writers are created but never written to below.
    writer = summary.create_file_writer(FLAGS.model_dir)
    writer_train = summary.create_file_writer(FLAGS.model_dir)
    writer_valid = summary.create_file_writer(FLAGS.model_dir)

    global_step = tf.train.get_or_create_global_step()

    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
    tf.add_to_collection('learning_rate', learning_rate)

    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    try:
        learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
    except Exception:
        learning_rate_weights = None

    ckpt_dir = FLAGS.model_dir + '/ckpt'
    # TODO FIXME tf code was patched locally so it does not keep only the
    # latest 5 checkpoints, see
    # https://github.com/tensorflow/tensorflow/issues/22036
    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    logging.info('Latest checkpoint:', latest_checkpoint)

    if not FLAGS.torch:
        optimizer = melt.get_optimizer(FLAGS.optimizer)(learning_rate)

        ckpt_objects = dict(learning_rate=learning_rate,
                            learning_rate_weight=learning_rate_weight,
                            model=model,
                            optimizer=optimizer,
                            global_step=global_step)
        if learning_rate_weights is not None:
            ckpt_objects['learning_rate_weights'] = learning_rate_weights
        checkpoint = tf.train.Checkpoint(**ckpt_objects)

        # FLAGS.model_dir may itself be a checkpoint prefix.
        if os.path.exists(FLAGS.model_dir + '.index'):
            latest_checkpoint = FLAGS.model_dir

        checkpoint.restore(latest_checkpoint)

        # Checkpoint names end in '-<epoch>'; only parse paths from ckpt dirs.
        start_epoch = (int(latest_checkpoint.split('-')[-1])
                       if latest_checkpoint and 'ckpt' in latest_checkpoint else 0)
    else:
        # TODO torch with learning rate adjust
        optimizer = torch.optim.Adamax(model.parameters(), lr=FLAGS.learning_rate)

        if latest_checkpoint:
            torch_ckpt = torch.load(latest_checkpoint + '.pyt')
            start_epoch = torch_ckpt['epoch']
            model.load_state_dict(torch_ckpt['state_dict'])
            optimizer.load_state_dict(torch_ckpt['optimizer'])
            model.eval()
        else:
            start_epoch = 0

    # TODO currently not support 0.1 epoch.. like this
    num_epochs = FLAGS.num_epochs

    class PytObj(object):
        """Wraps a plain value so it exposes ``.numpy()`` like a TF tensor."""

        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        """Running mean of torch scalars with a ``tfe.metrics.Mean``-like API.

        The accumulator is cleared lazily: the first ``__call__`` after a
        ``result()`` starts a new window.
        """

        def __init__(self):
            self._val = 0.
            self.count = 0
            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            if not self.count:
                val = 0
            else:
                val = self._val / self.count
            # Wrapped for compatibility with the tf metric interface.
            return PytObj(val)

    # TODO consider multiple gpu for torch
    data_iter = dataset.make_batch(batch_size, inputs, repeat=False, initializable=False)
    x_, y_ = data_iter.get_next()

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean
    epoch_loss_avg = Mean()

    sess = melt.get_session(device_count={'GPU': 0})
    # NOTE(review): rebinds global_step from the TF variable created above to
    # a plain Python counter, so restored TF step values are not continued.
    global_step = 0
    for epoch in range(start_epoch, num_epochs):
        melt.set_global('epoch', '%.4f' % (epoch))
        sess.run(data_iter.initializer)

        model.train()

        # With an unknown epoch size, step until the iterator is exhausted.
        steps = range(num_steps_per_epoch) if num_steps_per_epoch else itertools.count()
        # ..... still OOM... FIXME TODO
        try:
            for _ in tqdm(steps, total=num_steps_per_epoch, ascii=True):
                x, y = sess.run([x_, y_])
                x, y = to_torch(x, y)

                optimizer.zero_grad()
                loss = loss_fn(model, x, y)
                loss.backward()
                optimizer.step()

                epoch_loss_avg(loss)

                if global_step % FLAGS.interval_steps == 0:
                    print(global_step, epoch_loss_avg.result().numpy())

                global_step += 1
        except tf.errors.OutOfRangeError:
            # Input iterator exhausted: the epoch is over.
            pass
        # Report the epoch whether it ended by step count or by exhaustion.
        print('epoch:%d loss:%f' % (epoch, epoch_loss_avg.result().numpy()))
def train(Dataset, model, loss_fn, evaluate_fn=None, inference_fn=None,
          eval_fn=None, write_valid=True, valid_names=None, infer_names=None,
          infer_debug_names=None, valid_write_fn=None, infer_write_fn=None,
          valid_suffix='.valid', infer_suffix='.infer', write_streaming=False,
          optimizer=None, param_groups=None, init_fn=None, dataset=None,
          valid_dataset=None, test_dataset=None, sep=','):
    """Eager-mode train / validate / infer driver for TF or torch models.

    Builds train/valid/test datasets from ``FLAGS.train_input``,
    ``FLAGS.valid_input`` and ``FLAGS.test_input``. When ``FLAGS.fold`` is
    set, it instead holds out that fold's files for validation. It then
    restores the latest checkpoint from ``<model_dir>/ckpt``, falling back to
    ``<model_dir>/ckpt2``. After that it either runs a one-off validation or
    inference and exits (``FLAGS.work_mode`` 'valid', or containing 'test'),
    or trains for ``FLAGS.num_epochs`` epochs.

    During training it does the following:
      * logs throughput and losses every ``FLAGS.interval_steps``;
      * computes validation metrics every ``FLAGS.metric_eval_interval_steps``;
      * in torch mode, writes ``latest.pyt`` every ``FLAGS.save_interval_steps``;
      * writes sub-epoch checkpoints into ``ckpt2`` when
        ``FLAGS.save_interval_epochs < 1``;
      * applies step-wise learning-rate decay.
    After each epoch it saves a checkpoint, and optionally validates and
    runs inference.

    Args:
      Dataset: factory called as ``Dataset('train'|'valid'|'test')`` when the
        matching ``dataset``/``valid_dataset``/``test_dataset`` is not given.
        It may be None only if ``dataset`` is provided.
      model: TF (eager) or torch model, selected by ``FLAGS.torch``.
      loss_fn: ``loss_fn(model, x, y)``. In torch mode it also receives
        ``training=True`` when its signature has a ``training`` argument.
      evaluate_fn: optional ``(model, dataset, model_path, num_steps) ->
        (vals, names)``. It takes precedence over ``eval_fn``.
      inference_fn: optional ``(model, dataset, model_path, num_steps)``,
        used instead of the module-level ``inference`` helper.
      eval_fn, valid_names, valid_write_fn, valid_suffix: forwarded to the
        module-level ``evaluate`` helper. ``valid_names`` defaults to names
        derived from ``infer_names``.
      write_valid: when false, ``evaluate`` gets no model path.
      infer_names, infer_debug_names, infer_write_fn, infer_suffix: forwarded
        to ``inference``.
      write_streaming, sep: forwarded to both ``evaluate`` and ``inference``.
      optimizer: optional pre-built optimizer (TF or torch).
      param_groups: torch parameter groups for the default Adamax optimizer.
      init_fn: accepted for compatibility; unused here.
      dataset, valid_dataset, test_dataset: optional objects exposing
        ``make_batch`` / ``num_examples_per_epoch``, used instead of
        ``Dataset(...)``.
    """
    if Dataset is None:
        assert dataset
    if FLAGS.torch:
        # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model.to(device)

    inputs = gezi.list_files(FLAGS.train_input)
    inputs.sort()
    all_inputs = inputs

    batch_size = melt.batch_size()
    num_gpus = melt.num_gpus()
    # Batch size used for the valid / test datasets.
    batch_size_ = batch_size

    if FLAGS.fold is not None:
        # Files of the current fold are held out for validation.
        inputs = [
            x for x in inputs
            if not x.endswith('%d.record' % FLAGS.fold)
            and not x.endswith('%d.tfrecord' % FLAGS.fold)
        ]

    logging.info('inputs', len(inputs), inputs[:100])
    num_folds = FLAGS.num_folds or len(inputs) + 1

    train_dataset_ = dataset or Dataset('train')
    train_dataset = train_dataset_.make_batch(batch_size, inputs)
    num_examples = train_dataset_.num_examples_per_epoch('train')
    num_all_examples = num_examples

    valid_inputs = None
    if FLAGS.valid_input:
        valid_inputs = gezi.list_files(FLAGS.valid_input)
    elif FLAGS.fold is not None:
        train_set = set(inputs)
        # FLAGS.test_aug selects the held-out files containing 'aug' instead
        # of the ones without it.
        if not FLAGS.test_aug:
            valid_inputs = [x for x in all_inputs if 'aug' not in x and x not in train_set]
        else:
            valid_inputs = [x for x in all_inputs if 'aug' in x and x not in train_set]
    logging.info('valid_inputs', valid_inputs)

    if valid_inputs:
        valid_dataset_ = valid_dataset or Dataset('valid')
        valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
        # Repeating copy; one batch per logging interval feeds the running
        # valid loss.
        valid_dataset2 = valid_dataset_.make_batch(batch_size_, valid_inputs, repeat=True)
    else:
        # No validation files: disable validation. A caller-supplied
        # valid_dataset has not been batched and must not be used as-is.
        valid_dataset = None
        valid_dataset2 = None

    if num_examples:
        if FLAGS.fold is not None:
            num_examples = int(num_examples * (num_folds - 1) / num_folds)
        num_steps_per_epoch = -(-num_examples // batch_size)  # ceil division
    else:
        num_steps_per_epoch = None
    logging.info('num_train_examples:', num_examples)

    num_valid_examples = None
    num_valid_steps_per_epoch = None
    if FLAGS.valid_input:
        if valid_inputs:
            num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
            if num_valid_examples:
                num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
    elif FLAGS.fold is not None and num_examples:
        num_valid_examples = int(num_all_examples * (1 / num_folds))
        num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
    logging.info('num_valid_examples:', num_valid_examples)

    test_inputs = None
    if FLAGS.test_input:
        test_inputs = gezi.list_files(FLAGS.test_input)
        logging.info('test_inputs', test_inputs)

    num_test_examples = None
    num_test_steps_per_epoch = None
    if test_inputs:
        test_dataset_ = test_dataset or Dataset('test')
        test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
        num_test_examples = test_dataset_.num_examples_per_epoch('test')
        if num_test_examples:
            num_test_steps_per_epoch = -(-num_test_examples // batch_size_)
    else:
        test_dataset = None
    logging.info('num_test_examples:', num_test_examples)

    summary = tf.contrib.summary
    # All three writers target FLAGS.log_dir.
    writer = summary.create_file_writer(FLAGS.log_dir)
    writer_train = summary.create_file_writer(FLAGS.log_dir)
    writer_valid = summary.create_file_writer(FLAGS.log_dir)

    global_step = tf.train.get_or_create_global_step()

    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
    tf.add_to_collection('learning_rate', learning_rate)

    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    try:
        learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
    except Exception:
        learning_rate_weights = None

    # ckpt/ holds one checkpoint per epoch. ckpt2/ holds the sub-epoch
    # checkpoints written when save_interval_epochs < 1, which is useful for
    # large datasets.
    ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
    os.makedirs(ckpt_dir2, exist_ok=True)

    # TODO FIXME tf code was patched locally so it does not keep only the
    # latest 5 checkpoints, see
    # https://github.com/tensorflow/tensorflow/issues/22036
    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if not latest_checkpoint:
        latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
    logging.info('Latest checkpoint:', latest_checkpoint)

    # FLAGS.model_dir may itself be a checkpoint prefix. Valid/test modes
    # always load from it.
    if os.path.exists(FLAGS.model_dir + '.index'):
        latest_checkpoint = FLAGS.model_dir
    if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
        latest_checkpoint = FLAGS.model_dir

    checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
    checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

    # Defined up front so a caller-supplied torch optimizer leaves it bound.
    is_dynamic_opt = False
    ckpt_objects = dict(learning_rate=learning_rate,
                        learning_rate_weight=learning_rate_weight,
                        global_step=global_step)
    if learning_rate_weights is not None:
        ckpt_objects['learning_rate_weights'] = learning_rate_weights

    if not FLAGS.torch:
        try:
            optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate)
        except Exception:
            logging.warning(f'Fail to using {FLAGS.optimizer} use adam instead')
            optimizer = melt.get_optimizer('adam')(learning_rate)

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, **ckpt_objects)
        checkpoint.restore(latest_checkpoint)
        checkpoint2 = copy.deepcopy(checkpoint)

        # Checkpoint names end in '-<epoch>'; only parse paths from ckpt dirs.
        start_epoch = (int(latest_checkpoint.split('-')[-1])
                       if latest_checkpoint and 'ckpt' in latest_checkpoint else 0)
    else:
        # TODO torch with learning rate adjust
        if optimizer is None:
            import lele
            is_dynamic_opt = True
            if FLAGS.optimizer == 'noam':
                optimizer = lele.training.optimizers.NoamOpt(
                    128, 2, 4000, torch.optim.Adamax(model.parameters(), lr=0))
            elif FLAGS.optimizer == 'bert':
                num_train_steps = int(
                    num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs))
                num_warmup_steps = FLAGS.warmup_steps or int(
                    num_train_steps * FLAGS.warmup_proportion)
                logging.info('num_train_steps', num_train_steps, 'num_warmup_steps',
                             num_warmup_steps, 'warmup_proportion', FLAGS.warmup_proportion)
                optimizer = lele.training.optimizers.BertOpt(
                    FLAGS.learning_rate, FLAGS.min_learning_rate,
                    num_train_steps, num_warmup_steps,
                    torch.optim.Adamax(model.parameters(), lr=0))
            else:
                is_dynamic_opt = False
                optimizer = torch.optim.Adamax(
                    param_groups if param_groups else model.parameters(),
                    lr=FLAGS.learning_rate)

        start_epoch = 0
        latest_path = (latest_checkpoint + '.pyt' if latest_checkpoint
                       else os.path.join(FLAGS.model_dir, 'latest.pyt'))
        if not os.path.exists(latest_path):
            latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
        if os.path.exists(latest_path):
            logging.info('loading torch model from', latest_path)
            torch_ckpt = torch.load(latest_path)
            if not FLAGS.torch_finetune:
                start_epoch = torch_ckpt['epoch']
                global_step.assign(torch_ckpt['step'] + 1)
            load_torch_model(model, latest_path)
            if FLAGS.torch_load_optimizer:
                optimizer.load_state_dict(torch_ckpt['optimizer'])

        # TODO by this way restart can not change learning rate..
        checkpoint = tf.train.Checkpoint(**ckpt_objects)
        try:
            checkpoint.restore(latest_checkpoint)
        except Exception:
            # Best effort: a torch run may have no matching TF checkpoint.
            logging.warning(f'Fail to restore tf checkpoint {latest_checkpoint}')
        checkpoint2 = copy.deepcopy(checkpoint)

    if FLAGS.torch and is_dynamic_opt:
        optimizer._step = global_step.numpy()

    logging.info('optimizer:', optimizer)

    if FLAGS.torch_lr:
        learning_rate.assign(optimizer.rate(1))
    if FLAGS.torch:
        learning_rate.assign(optimizer.param_groups[0]['lr'])
        logging.info('learning rate got from pytorch latest.py as', learning_rate)

    learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
    if learning_rate_weights is not None:
        learning_rate_weights.assign(learning_rate_weights * FLAGS.learning_rate_start_factor)

    # Fractional epochs (e.g. 0.1) are not supported; num_epochs == 0 is
    # mapped to 1024 epochs.
    num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

    def _metric_names():
        # Names handed to evaluate(): explicit valid_names, else derived
        # from infer_names, else None.
        if valid_names is not None:
            return valid_names
        if infer_names:
            return [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:]
        return None

    def _sync_log_dir():
        # Mirror FLAGS.log_dir into FLAGS.model_dir when they differ.
        if FLAGS.log_dir != FLAGS.model_dir:
            assert FLAGS.log_dir
            command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
            print(command, file=sys.stderr)
            os.system(command)

    will_valid = (valid_dataset and FLAGS.work_mode != 'test'
                  and 'SHOW' not in os.environ and 'QUICK' not in os.environ)
    # A fresh run (start_epoch == 0) only validates up front when EVFIRST is set.
    if start_epoch == 0 and 'EVFIRST' not in os.environ:
        will_valid = False

    if will_valid:
        logging.info('----------valid')
        if FLAGS.torch:
            model.eval()
        names = None
        if evaluate_fn is not None:
            vals, names = evaluate_fn(model, valid_dataset,
                                      tf.train.latest_checkpoint(ckpt_dir),
                                      num_valid_steps_per_epoch)
        elif eval_fn:
            model_path = None if not write_valid else latest_checkpoint
            names = _metric_names()
            logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir)
            vals, names = evaluate(model, valid_dataset, eval_fn, model_path,
                                   names, valid_write_fn, write_streaming,
                                   num_valid_steps_per_epoch, suffix=valid_suffix, sep=sep)
        if names:
            logging.info2('epoch:%d/%d' % (start_epoch, num_epochs),
                          ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

    if FLAGS.work_mode == 'valid':
        sys.exit(0)

    if 'test' in FLAGS.work_mode:
        logging.info('--------test/inference')
        if test_dataset:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                assert latest_checkpoint
                inference(model, test_dataset, latest_checkpoint,
                          infer_names, infer_debug_names, infer_write_fn, write_streaming,
                          num_test_steps_per_epoch, suffix=infer_suffix, sep=sep)
            else:
                inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)
        sys.exit(0)

    if 'SHOW' in os.environ:
        num_epochs = start_epoch + 1

    class PytObj(object):
        """Wraps a plain value so it exposes ``.numpy()`` like a TF tensor."""

        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        """Running mean of torch scalars with a ``tfe.metrics.Mean``-like API.

        The accumulator is cleared lazily: the first ``__call__`` after a
        ``result()`` starts a new window.
        """

        def __init__(self):
            self._val = 0.
            self.count = 0
            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            if not self.count:
                val = 0
            else:
                val = self._val / self.count
            # Wrapped for compatibility with the tf metric interface.
            return PytObj(val)

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean

    def _epoch_progress(epoch, i):
        # Fractional epoch for logs; the plain epoch when its size is unknown.
        return epoch + i / num_steps_per_epoch if num_steps_per_epoch else epoch

    def _torch_state(epoch_value):
        # Serializable torch training state; unwraps DataParallel's .module.
        return {
            'epoch': epoch_value,
            'step': global_step.numpy(),
            'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

    timer = gezi.Timer()
    num_insts = 0

    if FLAGS.learning_rate_decay_factor > 0:
        # NOTICE: if finetuning changes batch_size, set num_steps_per_decay
        # directly. Otherwise global_step / decay_steps no longer corresponds
        # to epochs, because the number of steps per epoch has changed.
        assert FLAGS.num_steps_per_decay or (
            FLAGS.num_epochs_per_decay and num_steps_per_epoch
        ), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch'
        decay_steps = FLAGS.num_steps_per_decay or int(
            num_steps_per_epoch * FLAGS.num_epochs_per_decay)
        decay_start_step = FLAGS.decay_start_step or int(
            num_steps_per_epoch * FLAGS.decay_start_epoch)
        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        logging.info(
            'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
            .format(FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay,
                    decay_steps, FLAGS.decay_start_epoch, decay_start_step))

    # Step interval for sub-epoch checkpoints in ckpt2. It is None when
    # disabled or when the epoch size is unknown, and at least 1 to avoid a
    # modulo by zero.
    sub_epoch_save_steps = None
    if FLAGS.save_interval_epochs and FLAGS.save_interval_epochs < 1 and num_steps_per_epoch:
        sub_epoch_save_steps = max(int(num_steps_per_epoch * FLAGS.save_interval_epochs), 1)

    # Resolved once. getfullargspec also accepts keyword-only args and
    # annotations, and is not removed in newer Pythons.
    loss_fn_takes_training = FLAGS.torch and 'training' in inspect.getfullargspec(loss_fn).args

    for epoch in range(start_epoch, num_epochs):
        melt.set_global('epoch', '%.4f' % (epoch))

        if FLAGS.torch:
            model.train()

        epoch_loss_avg = Mean()
        epoch_valid_loss_avg = Mean()

        for i, (x, y) in enumerate(train_dataset):
            if FLAGS.torch:
                x, y = to_torch(x, y)
                if is_dynamic_opt:
                    learning_rate.assign(optimizer.rate())

            if not FLAGS.torch:
                loss, grads = melt.eager.grad(model, x, y, loss_fn)
                grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
                optimizer.apply_gradients(zip(grads, model.variables))
            else:
                optimizer.zero_grad()
                if loss_fn_takes_training:
                    loss = loss_fn(model, x, y, training=True)
                else:
                    loss = loss_fn(model, x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), FLAGS.clip_gradients)
                optimizer.step()

            global_step.assign_add(1)

            epoch_loss_avg(loss)  # add current batch loss
            if FLAGS.torch:
                del loss

            cur_batch_size = (next(iter(x.values())).shape[FLAGS.batch_size_dim]
                              if isinstance(x, dict) else x.shape[FLAGS.batch_size_dim])
            num_insts += int(cur_batch_size)

            if global_step.numpy() % FLAGS.interval_steps == 0:
                elapsed = timer.elapsed()
                steps_per_second = FLAGS.interval_steps / elapsed
                instances_per_second = num_insts / elapsed
                num_insts = 0

                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)

                if valid_dataset2:
                    try:
                        vx, vy = next(iter(valid_dataset2))
                    except Exception:
                        # TODO FIXME how.. iterate stop restart.., here hack for my iterator see projects/lm/dataset
                        vx, vy = next(iter(valid_dataset2))
                    if FLAGS.torch:
                        vx, vy = to_torch(vx, vy)
                        model.eval()
                    valid_loss = loss_fn(model, vx, vy)
                    epoch_valid_loss_avg(valid_loss)
                    if FLAGS.torch:
                        model.train()

                log_fields = [
                    'epoch:%.3f/%d' % (_epoch_progress(epoch, i), num_epochs),
                    'step:%d' % global_step.numpy(),
                    'elapsed:[%.3f]' % elapsed,
                    'batch_size:[%d]' % cur_batch_size,
                    'gpus:[%d]' % num_gpus,
                    'batches/s:[%.2f]' % steps_per_second,
                    'insts/s:[%d]' % instances_per_second,
                    '%s' % epoch_time_info,
                    'lr:[%.8f]' % learning_rate.numpy(),
                    'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
                ]
                if valid_dataset2:
                    log_fields.append('valid_loss:[%.4f]' % epoch_valid_loss_avg.result().numpy())
                logging.info2(*log_fields)

                if global_step.numpy() % FLAGS.eval_interval_steps == 0:
                    if valid_dataset2:
                        with writer_valid.as_default(), summary.always_record_summaries():
                            summary.scalar('loss/eval', epoch_valid_loss_avg.result().numpy())
                            writer_valid.flush()
                    with writer_train.as_default(), summary.always_record_summaries():
                        summary.scalar('loss/train_avg', epoch_loss_avg.result().numpy())
                        summary.scalar('learning_rate', learning_rate.numpy())
                        summary.scalar('batch_size', cur_batch_size)
                        summary.scalar('epoch', melt.epoch())
                        summary.scalar('steps_per_second', steps_per_second)
                        summary.scalar('instances_per_second', instances_per_second)
                        writer_train.flush()
                    _sync_log_dir()

            if (valid_dataset and FLAGS.metric_eval_interval_steps
                    and global_step.numpy()
                    and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0):
                if FLAGS.torch:
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset, None, num_valid_steps_per_epoch)
                elif eval_fn:
                    vals, names = evaluate(model, valid_dataset, eval_fn, None,
                                           _metric_names(), valid_write_fn, write_streaming,
                                           num_valid_steps_per_epoch, sep=sep)
                if vals and names:
                    with writer_valid.as_default(), summary.always_record_summaries():
                        for name, val in zip(names, vals):
                            summary.scalar(f'step/valid/{name}', val)
                        writer_valid.flush()

                if FLAGS.torch:
                    if not FLAGS.torch_lr:
                        # Keep the torch param-group lr in sync with the TF
                        # learning_rate variable (it carries lr decay).
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = learning_rate.numpy()
                    model.train()

                if names and vals:
                    logging.info2(
                        'epoch:%.3f/%d' % (_epoch_progress(epoch, i), num_epochs),
                        'valid_step:%d' % global_step.numpy(),
                        'valid_metrics',
                        ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

            if global_step.numpy() % FLAGS.save_interval_steps == 0:
                if FLAGS.torch:
                    torch.save(_torch_state(epoch), os.path.join(FLAGS.model_dir, 'latest.pyt'))

            # TODO fixme why if both checpoint2 and chekpoint used... not ok..
            if sub_epoch_save_steps and global_step.numpy() % sub_epoch_save_steps == 0:
                checkpoint2.save(checkpoint_prefix2)
                if FLAGS.torch:
                    torch.save(_torch_state(epoch), tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

            if FLAGS.learning_rate_decay_factor > 0:
                step = global_step.numpy()
                if step >= decay_start_step and step % decay_steps == 0:
                    lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor,
                             FLAGS.min_learning_rate)
                    if lr < learning_rate.numpy():
                        learning_rate.assign(lr)
                        if FLAGS.torch:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate.numpy()

            if epoch == start_epoch and i == 0:
                try:
                    if not FLAGS.torch:
                        logging.info(model.summary())
                except Exception:
                    traceback.print_exc()
                    logging.info(
                        'Fail to do model.summary() may be you have layer define in init but not used in call')
                if 'SHOW' in os.environ:
                    sys.exit(0)

        logging.info2(
            'epoch:%d/%d' % (epoch + 1, num_epochs),
            'step:%d' % global_step.numpy(),
            'batch_size:[%d]' % batch_size,
            'gpus:[%d]' % num_gpus,
            'lr:[%.8f]' % learning_rate.numpy(),
            'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
            'valid_loss::[%.4f]' % epoch_valid_loss_avg.result().numpy())

        timer = gezi.Timer(
            f'save model to {checkpoint_prefix}-{checkpoint.save_counter.numpy() + 1}', False)
        checkpoint.save(checkpoint_prefix)
        if FLAGS.torch and FLAGS.save_interval_epochs == 1:
            torch.save(_torch_state(epoch + 1), tf.train.latest_checkpoint(ckpt_dir) + '.pyt')
        timer.print_elapsed()

        # Reset per epoch, so the epoch summary never reports metrics from
        # an earlier step-level evaluation.
        vals, names = None, None
        if valid_dataset and (epoch + 1) % FLAGS.valid_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()
            if evaluate_fn is not None:
                vals, names = evaluate_fn(model, valid_dataset,
                                          tf.train.latest_checkpoint(ckpt_dir),
                                          num_valid_steps_per_epoch)
            elif eval_fn:
                model_path = None if not write_valid else tf.train.latest_checkpoint(ckpt_dir)
                vals, names = evaluate(model, valid_dataset, eval_fn, model_path,
                                       _metric_names(), valid_write_fn, write_streaming,
                                       num_valid_steps_per_epoch, suffix=valid_suffix, sep=sep)
            if vals and names:
                logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                              'step:%d' % global_step.numpy(),
                              'epoch_valid_metrics',
                              ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

        with writer.as_default(), summary.always_record_summaries():
            # Epoch-level scalars are recorded with global_step temporarily
            # set to the epoch number, then restored.
            temp = global_step.value()
            global_step.assign(epoch + 1)
            summary.scalar('epoch/train/loss', epoch_loss_avg.result().numpy())
            if valid_dataset:
                if FLAGS.torch:
                    model.eval()
                if vals and names:
                    for name, val in zip(names, vals):
                        summary.scalar(f'epoch/valid/{name}', val)
            writer.flush()
            global_step.assign(temp)

        if test_dataset and (epoch + 1) % FLAGS.inference_interval_epochs == 0:
            if FLAGS.torch:
                model.eval()
            if inference_fn is None:
                inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir),
                          infer_names, infer_debug_names, infer_write_fn, write_streaming,
                          num_test_steps_per_epoch, suffix=infer_suffix, sep=sep)
            else:
                inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)

    _sync_log_dir()

    command = 'rm -rf %s/latest.pyt.*' % (FLAGS.model_dir)
    print(command, file=sys.stderr)
    os.system(command)
def train(model, loss_fn, Dataset=None, dataset=None, valid_dataset=None, valid_dataset2=None, test_dataset=None, evaluate_fn=None, inference_fn=None, eval_fn=None, write_valid=True, valid_names=None, infer_names=None, infer_debug_names=None, valid_write_fn=None, infer_write_fn=None, valid_suffix='.valid', infer_suffix='.infer', write_streaming=False, optimizer=None, param_groups=None, init_fn=None, sep=','): use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ if Dataset is None: assert dataset logging.info('Dataset', Dataset, 'dataset', dataset, 'valid_dataset', valid_dataset, 'test_dataset', test_dataset, loss_fn) if FLAGS.torch: torch.manual_seed(FLAGS.seed or 0) if torch.cuda.device_count(): torch.cuda.manual_seed(FLAGS.seed or 0) if use_horovod: pass # import horovod.torch as hvd # hvd.init() # #print('-----------------', hvd, hvd.size()) # assert hvd.mpi_threads_supported() # assert hvd.size() == comm.Get_size() # torch.cuda.set_device(hvd.local_rank()) # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html else: if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) model.to(device) input_ = FLAGS.train_input inputs = gezi.list_files(input_) inputs.sort() all_inputs = inputs #batch_size = FLAGS.batch_size batch_size = melt.batch_size() num_gpus = melt.num_gpus() #batch_size = max(batch_size, 1) #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1]) batch_size_ = FLAGS.eval_batch_size or batch_size if dataset is None: if FLAGS.fold is not None: inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold) and not x.endswith('%d.tfrecord' % FLAGS.fold)] # if FLAGS.valid_input: # inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)] logging.info('inputs', len(inputs), inputs[:100]) num_folds = FLAGS.num_folds or len(inputs) + 1 if dataset is None: dataset = Dataset('train') assert len(inputs) > 0 train_dataset = 
dataset.make_batch(batch_size, inputs, simple_parse=FLAGS.simple_parse) num_examples = dataset.num_examples_per_epoch('train') else: assert FLAGS.torch_only, 'only torch only currently support input dataset not Dataset class type, because we do not have len function there' train_dataset = dataset num_examples = len(train_dataset.dataset) num_all_examples = num_examples if valid_dataset is None: valid_inputs = None if FLAGS.valid_input: valid_inputs = gezi.list_files(FLAGS.valid_input) else: if FLAGS.fold is not None: #valid_inputs = [x for x in all_inputs if x not in inputs] if not FLAGS.test_aug: valid_inputs = [x for x in all_inputs if not 'aug' in x and x not in inputs] else: valid_inputs = [x for x in all_inputs if 'aug' in x and x not in inputs] logging.info('valid_inputs', valid_inputs) num_valid_examples = None if valid_dataset is not None: num_valid_examples = len(valid_dataset.dataset) num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None valid_dataset2_iter = iter(valid_dataset2) else: if valid_inputs: valid_dataset = dataset.make_batch(batch_size_, valid_inputs, subset='valid', hvd_shard=FLAGS.horovod_eval ) valid_dataset2 = dataset.make_batch(batch_size, valid_inputs, subset='valid', repeat=True, initializable=False, hvd_shard=False) valid_dataset2_iter = iter(valid_dataset2) else: valid_datsset = None valid_dataset2 = None if num_examples: if FLAGS.fold is not None: num_examples = int(num_examples * (num_folds - 1) / num_folds) num_steps_per_epoch = -(-num_examples // batch_size) else: num_steps_per_epoch = None logging.info('num_train_examples:', num_examples) if use_horovod and num_examples: num_steps_per_epoch = -(-num_examples // (batch_size * hvd.size())) if num_valid_examples is None: if FLAGS.valid_input: num_valid_examples = dataset.num_examples_per_epoch('valid') num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None else: if FLAGS.fold is not None: if 
num_examples: num_valid_examples = int(num_all_examples * (1 / num_folds)) num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) else: num_valid_steps_per_epoch = None if use_horovod and FLAGS.horovod_eval and num_valid_examples: num_valid_steps_per_epoch = -(-num_valid_examples // (batch_size_ * hvd.size())) logging.info('num_valid_examples:', num_valid_examples) if test_dataset is None: if FLAGS.test_input: test_inputs = gezi.list_files(FLAGS.test_input) #test_inputs = [x for x in test_inputs if not 'aug' in x] logging.info('test_inputs', test_inputs) else: test_inputs = None num_test_examples = None if test_dataset is not None: num_test_examples = len(test_dataset.dataset) else: if test_inputs: test_dataset = dataset.make_batch(batch_size_, test_inputs, subset='test') num_test_examples = dataset.num_examples_per_epoch('test') else: test_dataset = None num_test_steps_per_epoch = -(-num_test_examples // batch_size_) if num_test_examples else None if use_horovod and FLAGS.horovod_eval and num_test_examples: num_test_steps_per_epoch = -(-num_test_examples // (batch_size_ * hvd.size())) logging.info('num_test_examples:', num_test_examples) summary = tf.contrib.summary # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch') # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train') # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid') writer = summary.create_file_writer(FLAGS.log_dir) writer_train = summary.create_file_writer(FLAGS.log_dir) writer_valid = summary.create_file_writer(FLAGS.log_dir) global_step = tf.train.get_or_create_global_step() ## RuntimeError: tf.summary.FileWriter is not compatible with eager execution. Use tf.contrib.summary instead. 
#logger = gezi.SummaryWriter(FLAGS.log_dir) learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate") tf.add_to_collection('learning_rate', learning_rate) learning_rate_weight = tf.get_collection('learning_rate_weight')[-1] try: learning_rate_weights = tf.get_collection('learning_rate_weights')[-1] except Exception: learning_rate_weights = None # ckpt dir save models one per epoch ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt') os.system('mkdir -p %s' % ckpt_dir) # HACK ckpt dir is actually save mini epoch like when you set save_interval_epochs=0.1, this is usefull when you training large dataset ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2') os.system('mkdir -p %s' % ckpt_dir2) #TODO FIXME now I just changed tf code so to not by default save only latest 5 # refer to https://github.com/tensorflow/tensorflow/issues/22036 # manager = tf.contrib.checkpoint.CheckpointManager( # checkpoint, directory=ckpt_dir, max_to_keep=5) # latest_checkpoint = manager.latest_checkpoint latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir) if latest_checkpoint: logging.info('Latest checkpoint:', latest_checkpoint) else: latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2) logging.info('Latest checkpoint:', latest_checkpoint) if os.path.exists(FLAGS.model_dir + '.index'): latest_checkpoint = FLAGS.model_dir if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode: #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir latest_checkpoint = FLAGS.model_dir #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint) checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt') checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt') if not FLAGS.torch: try: optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate) except Exception: logging.warning(f'Fail to using {FLAGS.optimizer} use adam instead') optimizer = melt.get_optimizer('adam')(learning_rate) # TODO... 
if learning_rate_weights is None: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, model=model, optimizer=optimizer, global_step=global_step) else: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, learning_rate_weights=learning_rate_weights, model=model, optimizer=optimizer, global_step=global_step) checkpoint.restore(latest_checkpoint) checkpoint2 = copy.deepcopy(checkpoint) start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0 start_step = 0 # TODO else: # TODO torch with learning rate adjust # https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py # TODO full support for pytorch now not work if optimizer is None: import lele is_dynamic_opt = True if FLAGS.optimizer == 'noam': optimizer_ = torch.optim.Adamax(model.parameters(), lr=0) if use_horovod: optimizer_ = hvd.DistributedOptimizer(optimizer_) optimizer = lele.training.optimizers.NoamOpt(128, 2, 4000, optimzier_) elif FLAGS.optimizer == 'bert': num_train_steps = int(num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs)) if FLAGS.warmup_steps and use_horovod: FLAGS.warmup_steps = max(int(FLAGS.warmup_steps / hvd.size()), 1) num_warmup_steps = FLAGS.warmup_steps or int(num_steps_per_epoch * FLAGS.warmup_epochs) or int(num_train_steps * FLAGS.warmup_proportion) logging.info('num_train_steps', num_train_steps, 'num_warmup_steps', num_warmup_steps, 'warmup_proportion', FLAGS.warmup_proportion) optimizer_ = torch.optim.Adamax(model.parameters(), lr=0) if use_horovod: optimizer_ = hvd.DistributedOptimizer(optimizer_) optimizer = lele.training.optimizers.BertOpt( FLAGS.learning_rate, FLAGS.min_learning_rate, num_train_steps, num_warmup_steps, optimizer_ ) else: is_dynamic_opt = False optimizer = torch.optim.Adamax(param_groups if param_groups else model.parameters(), lr=FLAGS.learning_rate) if use_horovod: optimizer 
= hvd.DistributedOptimizer(optimizer) optimizer_ = optimizer else: if use_horovod: optimizer = hvd.DistributedOptimizer(optimizer) optimizer_ = optimizer start_epoch = 0 latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(FLAGS.model_dir, 'latest.pyt') if not os.path.exists(latest_path): latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt') if os.path.exists(latest_path): logging.info('loading torch model from', latest_path) checkpoint = torch.load(latest_path) if not FLAGS.torch_finetune: start_epoch = checkpoint['epoch'] step = checkpoint['step'] global_step.assign(step + 1) load_torch_model(model, latest_path) if FLAGS.torch_load_optimizer: optimizer.load_state_dict(checkpoint['optimizer']) # TODO by this way restart can not change learning rate.. if learning_rate_weights is None: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, global_step=global_step) else: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, learning_rate_weights=learning_rate_weights, global_step=global_step) try: checkpoint.restore(latest_checkpoint) checkpoint2 = copy.deepcopy(checkpoint) except Exception: pass if FLAGS.torch and is_dynamic_opt: optimizer._step = global_step.numpy() #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1')) #model.save('./weight3.hd5') logging.info('optimizer:', optimizer) if FLAGS.torch_lr: learning_rate.assign(optimizer.rate(1)) if FLAGS.torch: learning_rate.assign(optimizer.param_groups[0]['lr']) logging.info('learning rate got from pytorch latest.py as', learning_rate.numpy()) learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor) if learning_rate_weights is not None: learning_rate_weights.assign(learning_rate_weights * FLAGS.learning_rate_start_factor) # TODO currently not support 0.1 epoch.. 
like this num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024 will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ if global_step.numpy() == 0 : will_valid = False if gezi.get_env('EVFIRST') == '1': will_valid = True if gezi.get_env('EVFIRST') == '0': will_valid = False if will_valid: logging.info('----------valid') if hasattr(model, 'eval'): model.eval() names = None if evaluate_fn is not None: vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch) elif eval_fn: model_path = None if not write_valid else latest_checkpoint names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir) vals, names = evaluate(model, valid_dataset, eval_fn, model_path, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, num_valid_examples, suffix=valid_suffix, sep=sep) if names: logging.info2('epoch:%.2f/%d step:%d' % (global_step.numpy() / num_steps_per_epoch, num_epochs, global_step.numpy()), ['%s:%.4f' % (name, val) for name, val in zip(names, vals)]) if FLAGS.work_mode == 'valid' or gezi.get_env('METRIC') == '1': exit(0) if 'test' in FLAGS.work_mode or gezi.get_env('TEST') == '1' or gezi.get_env('INFER') == '1': logging.info('--------test/inference') if test_dataset: if hasattr(model, eval): model.eval() if inference_fn is None: # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint # logging.info('model_path', model_path) assert latest_checkpoint inference(model, test_dataset, latest_checkpoint, infer_names, infer_debug_names, infer_write_fn, write_streaming, num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix) else: inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch) exit(0) if 'SHOW' 
in os.environ: num_epochs = start_epoch + 1 class PytObj(object): def __init__(self, x): self.x = x def numpy(self): return self.x class PytMean(object): def __init__(self): self._val = 0. self.count = 0 self.is_call = True def clear(self): self._val = 0 self.count = 0 def __call__(self, val): if not self.is_call: self.clear() self.is_call = True self._val += val.item() self.count += 1 def result(self): if self.is_call: self.is_call = False if not self.count: val = 0 else: val = self._val / self.count # TODO just for compact with tf .. return PytObj(val) Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean num_insts = 0 if FLAGS.learning_rate_decay_factor > 0: #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?' #NOTICE if you do finetune or other things which might change batch_size then you'd better direclty set num_steps_per_decay #since global step / decay_steps will not be correct epoch as num_steps per epoch changed #so if if you change batch set you have to reset global step as fixed step assert FLAGS.num_steps_per_decay or (FLAGS.num_epochs_per_decay and num_steps_per_epoch), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch' decay_steps = FLAGS.num_steps_per_decay or int(num_steps_per_epoch * FLAGS.num_epochs_per_decay) decay_start_step = FLAGS.decay_start_step or int(num_steps_per_epoch * FLAGS.decay_start_epoch) # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps) logging.info('learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'.format( FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay, decay_steps, FLAGS.decay_start_epoch, decay_start_step)) #-------------------------start training if hasattr(model, 'train'): model.train() if use_horovod: hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer_, root_rank=0) timer = gezi.Timer() loss_avg = Mean() 
valid_loss_avg = Mean() num_epochs = num_epochs if num_epochs else 0 loops = min(num_epochs, 1) if FLAGS.torch_only else 1 for _ in range(loops): for i, (x, y) in enumerate(train_dataset): if FLAGS.torch: x, y = to_torch(x, y) if is_dynamic_opt: learning_rate.assign(optimizer.rate()) def loss_fn_(x, y): if not FLAGS.torch and 'training' in inspect.getargspec(model.call).args: y_ = model(x, training=True) else: y_ = model(x) if not FLAGS.torch: return loss_fn(y, y_) else: return loss_fn(y_, y) if not FLAGS.torch: loss, grads = melt.eager.grad(model, x, y, loss_fn) grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients) #optimizer.apply_gradients(zip(grads, model.variables)) optimizer.apply_gradients(zip(grads, model.trainable_variables)) # https://github.com/horovod/horovod/blob/master/examples/tensorflow_mnist_eager.py # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when # training is started with random weights or restored from a checkpoint. # Note: broadcast should be done after the first gradient step to ensure optimizer # initialization. 
# TODO check eager mode if use_horovod and epoch == start_epoch and i == 0: hvd.broadcast_variables(model.variables, root_rank=0) hvd.broadcast_variables(optimizier.variables(), root_rank=0) else: optimizer.zero_grad() loss = loss_fn_(x, y) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), FLAGS.clip_gradients) optimizer.step() global_step.assign_add(1) loss_avg(loss) ## https://discuss.pytorch.org/t/calling-loss-backward-reduce-memory-usage/2735 # if FLAGS.torch: # del loss batch_size_ = list(x.values())[0].shape[FLAGS.batch_size_dim] if type(x) == type({}) else x.shape[FLAGS.batch_size_dim] num_insts += int(batch_size_) if global_step.numpy() % FLAGS.interval_steps == 0: #checkpoint.save(checkpoint_prefix) elapsed = timer.elapsed() steps_per_second = FLAGS.interval_steps / elapsed instances_per_second = num_insts / elapsed num_insts = 0 if num_steps_per_epoch is None: epoch_time_info = '' else: hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600 epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch) if valid_dataset2: # try: # x, y = next(iter(valid_dataset2)) # except Exception: # # TODO FIXME how.. iterate stop restart.., here hack for my iterator see projects/lm/dataset # x, y = next(iter(valid_dataset2)) ## valid dataset2 is repeated ## NOTICE will always the first batch ... 
as below #x, y = next(iter(valid_dataset2)) x, y = next(valid_dataset2_iter) #print(x['id'][0]) if FLAGS.torch: x, y = to_torch(x, y) if hasattr(model, 'eval'): model.eval() valid_loss = loss_fn_(x, y) valid_loss = valid_loss.numpy() if not FLAGS.torch else valid_loss.item() if hasattr(model, 'train'): model.train() if not use_horovod or hvd.rank() == 0: # 'train_loss:[%.4f]' % loss_avg.result().numpy(), # 'valid_loss:[%.4f]' % valid_loss_avg.result().numpy() logging.info2('epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch), num_epochs), 'step:%d' % global_step.numpy(), 'elapsed:[%.2f]' % elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' % num_gpus, 'batches/s:[%.2f]' % steps_per_second, 'insts/s:[%d]' % instances_per_second, '%s' % epoch_time_info, 'lr:[%.6f]' % learning_rate.numpy(), 'train_loss:[%.4f]' % loss_avg.result().numpy(), 'valid_loss:[%.4f]' % valid_loss ) if global_step.numpy() % FLAGS.valid_interval_steps == 0: with writer_valid.as_default(), summary.always_record_summaries(): summary.scalar('loss/valid', valid_loss) writer_valid.flush() else: if not use_horovod or hvd.rank() == 0: #'train_loss:[%.4f]' % loss_avg.result().numpy() logging.info2('epoch:%.2f/%d' % ((epoch + i / num_steps_per_epoch), num_epochs), 'step:%d' % global_step.numpy(), 'elapsed:[%.2f]' % elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' % num_gpus, 'batches/s:[%.2f]' % steps_per_second, 'insts/s:[%d]' % instances_per_second, '%s' % epoch_time_info, 'lr:[%.6f]' % learning_rate.numpy(), 'train_loss:[%.4f]' % loss_avg.result().numpy() ) if not use_horovod or hvd.rank() == 0: if global_step.numpy() % FLAGS.valid_interval_steps == 0: with writer_train.as_default(), summary.always_record_summaries(): summary.scalar('loss/train_avg', loss_avg.result().numpy()) summary.scalar('learning_rate', learning_rate.numpy()) summary.scalar('other/batch_size', batch_size_) summary.scalar('other/epoch', melt.epoch()) summary.scalar('perf/steps_per_second', steps_per_second) 
summary.scalar('perf/instances_per_second', instances_per_second) writer_train.flush() if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy() and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0: if hasattr(model, eval): model.eval() vals, names = None, None if evaluate_fn is not None: vals, names = evaluate_fn(model, valid_dataset, None, num_valid_steps_per_epoch) elif eval_fn: names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None vals, names = evaluate(model, valid_dataset, eval_fn, None, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, num_valid_examples, sep=sep) if not use_horovod or hvd.rank() == 0: if vals and names: with writer_valid.as_default(), summary.always_record_summaries(): for name, val in zip(names, vals): summary.scalar(f'step_eval/{name}', val) writer_valid.flush() if FLAGS.torch: if not FLAGS.torch_lr: # control learning rate by tensorflow learning rate for param_group in optimizer.param_groups: # important learning rate decay param_group['lr'] = learning_rate.numpy() if hasattr(model, 'train'): model.train() if not use_horovod or hvd.rank() == 0: if names and vals: logging.info2('epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch), num_epochs), 'valid_step:%d' % global_step.numpy(), 'valid_metrics', ['%s:%.5f' % (name, val) for name, val in zip(names, vals)]) if not use_horovod or hvd.rank() == 0: # TODO save ok ? if global_step.numpy() % FLAGS.save_interval_steps == 0: if FLAGS.torch: state = { 'epoch': int(global_step.numpy() / num_steps_per_epoch), 'step': global_step.numpy(), 'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(), 'optimizer' : optimizer.state_dict(), } torch.save(state, os.path.join(FLAGS.model_dir, 'latest.pyt')) # TODO fixme why if both checpoint2 and chekpoint used... not ok.. 
if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0: checkpoint2.save(checkpoint_prefix2) if FLAGS.torch: state = { 'epoch': int(global_step.numpy() / num_steps_per_epoch), 'step': global_step.numpy(), 'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(), 'optimizer' : optimizer.state_dict(), } torch.save(state, tf.train.latest_checkpoint(ckpt_dir2) + '.pyt') if FLAGS.learning_rate_decay_factor > 0: if global_step.numpy() >= decay_start_step and global_step.numpy() % decay_steps == 0: lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor, FLAGS.min_learning_rate) if lr < learning_rate.numpy(): learning_rate.assign(lr) if FLAGS.torch: for param_group in optimizer.param_groups: param_group['lr'] = learning_rate.numpy() if i == 0: try: if not FLAGS.torch: logging.info(model.summary()) # #tf.keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='TB') # import keras # keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='LR', expand_nested=True, dpi=96) else: logging.info(model) except Exception: traceback.print_exc() logging.info('Fail to do model.summary() may be you have layer define in init but not used in call') if 'SHOW' in os.environ: exit(0) if valid_dataset and global_step.numpy() % int(num_steps_per_epoch * FLAGS.valid_interval_epochs) == 0: if hasattr(model, 'eval'): model.eval() vals, names = None, None if evaluate_fn is not None: vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch) elif eval_fn: model_path = None if not write_valid else tf.train.latest_checkpoint(ckpt_dir) print('---------metric evaluate step', global_step.numpy(), 'model_path:', model_path) names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + 
infer_names[1:] if infer_names else None vals, names = evaluate(model, valid_dataset, eval_fn, model_path, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, num_valid_examples, suffix=valid_suffix, sep=sep) if not use_horovod or hvd.rank() == 0: if vals and names: logging.info2('epoch:%.2f/%d' % (global_step.numpy() / num_steps_per_epoch, num_epochs), 'step:%d' % global_step.numpy(), 'valid_metrics', ['%s:%.5f' % (name, val) for name, val in zip(names, vals)]) if not use_horovod or hvd.rank() == 0: with writer.as_default(), summary.always_record_summaries(): temp = global_step.value() global_step.assign(int(global_step.numpy() / int(num_steps_per_epoch * FLAGS.valid_interval_epochs))) if valid_dataset: if hasattr(model, 'eval'): model.eval() if vals and names: for name, val in zip(names, vals): summary.scalar(f'eval/{name}', val) writer.flush() global_step.assign(temp) if test_dataset and global_step.numpy() % int(num_steps_per_epoch * FLAGS.inference_interval_epochs) == 0: if hasattr(model, 'eval'): model.eval() if inference_fn is None: inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), infer_names, infer_debug_names, infer_write_fn, write_streaming, num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix, sep=sep) else: inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch) if num_epochs and (global_step.numpy() % num_steps_per_epoch) == 0 and int(global_step.numpy() / num_steps_per_epoch) == num_epochs: logging.info(f'Finshed training of {num_epochs} epochs') exit(0)