def add(self, score):
    """Record a new evaluation `score` and decay `self.weight_op` on stalled progress.

    Tracks a patience counter; once `self.max_patience` consecutive
    non-improving scores are seen (and the epoch has reached
    `self.decay_start_epoch`), multiplies the weight by `self.decay`,
    clamped at `self.min_weight`.

    Returns:
        The current (possibly decayed) scalar weight value, or None when the
        epoch is still before `decay_start_epoch` on a non-improving step.

    NOTE(review): the comparison direction ("bigger is better" vs "smaller is
    better") is inferred lazily from the first score change — TODO confirm
    this is the intended contract for all callers.
    """
    import melt.utils.logging as logging
    # Read the current weight value; graph mode needs a session run,
    # eager mode holds the value directly.
    if not tf.executing_eagerly():
        weight = self.sess.run(self.weight_op)
    else:
        weight = self.weight_op
    #print(weight, score, self.score, self.patience)
    # Lazily choose the comparator from the first observed score movement.
    if (not self.cmp) and self.score:
        if score > self.score:
            self.cmp = lambda x, y: x > y
        else:
            self.cmp = lambda x, y: x < y
        logging.info('decay cmp:', self.cmp)
    if not self.score or self.cmp(score, self.score):
        # Improvement (or very first score): remember it and reset patience.
        self.score = score
        self.patience = 0
    else:
        self.patience += 1
        # epoch is set during training loop
        epoch = melt.epoch()
        logging.info('patience:', self.patience)
        if epoch < self.decay_start_epoch:
            return
        if self.patience >= self.max_patience:
            self.count += 1
            self.patience = 0
            self.score = score
            decay = self.decay
            pre_weight = weight
            #weight *= decay
            weight = weight * decay  # decay
            # Clamp at the configured floor; recompute the effective decay
            # ratio actually applied so the learning-rate update below matches.
            if self.min_weight and weight < self.min_weight:
                weight = self.min_weight
                decay = weight / pre_weight
                if decay > 1.:
                    decay = 1.
            logging.info('!decay count:', self.count, self.name, 'now:', weight)
            # Write the decayed weight back (session assign vs plain rebind).
            if not tf.executing_eagerly():
                self.sess.run(tf.assign(self.weight_op, tf.constant(weight, dtype=tf.float32)))
            else:
                self.weight_op = weight
            # Weights named like a learning rate also scale the shared
            # 'learning_rate' collection variable by the same effective decay.
            if 'learning_rate' in self.name:
                if not tf.executing_eagerly():
                    melt.multiply_learning_rate(tf.constant(decay, dtype=tf.float32), self.sess)
                else:
                    # TODO need to test eager mode
                    #learning_rate = tf.get_collection('learning_rate')[-1]
                    #if learning_rate * decay > self.min_learning_rate:
                    #tf.get_collection('learning_rate')[-1] *= decay
                    tf.get_collection('learning_rate')[-1].assign(tf.get_collection('learning_rate')[-1] * decay)
    return weight
def print_img(img, i):
    """Log one image row (via the `img_html` template) with current training stats."""
    # Absolute URLs pass through untouched; bare names get the configured prefix.
    if img.startswith("http://"):
        img_url = img
    else:
        img_url = FLAGS.image_url_prefix + img
    logging.info(
        img_html.format(img_url, i, img, melt.epoch(), melt.step(),
                        melt.train_loss(), melt.eval_loss(), melt.duration(),
                        gezi.now_time()))
def add(self, scores):
    """Multi-weight variant of `add`: per-score patience tracking and decay.

    For each score in `scores`, resets its patience on improvement, otherwise
    increments it; after `self.max_patience` stalls (and once past
    `self.decay_start_epoch`) the matching entry of `self.weights_op` is
    multiplied by the (possibly per-index) decay factor and clamped at
    `self.min_weight`.

    Returns:
        The updated numpy weights array.

    Fix(review): the min-weight clamp was guarded by `if not self.min_weight:`,
    i.e. it only ran when NO floor was configured (and would then compare
    against a falsy floor). The guard is now `if self.min_weight:`, matching
    the single-score `add` implementation.

    NOTE(review): `self.scores is None` inside the loop can never be True once
    `self.scores[i]` is assigned — presumably a leftover first-call guard;
    verify initialization in __init__.
    """
    import melt.utils.logging as logging
    scores = np.array(scores)
    logging.info('diff:', list(zip(self.names, scores - self.scores)))
    # Fetch current weights; keep a mutable numpy copy in weights_.
    if not tf.executing_eagerly():
        weights = self.sess.run(self.weights_op)
        weights_ = weights
    else:
        weights = self.weights_op
        weights_ = weights.numpy()
    # Lazily choose the comparator from the first movement of scores[0].
    if (not self.cmp) and self.scores:
        if scores[0] > self.scores[0]:
            self.cmp = lambda x, y: x > y
        else:
            self.cmp = lambda x, y: x < y
        logging.info('decay cmp:', self.cmp)
    # epoch is set during training loop
    epoch = melt.epoch()
    for i, score in enumerate(scores):
        if self.scores is None or self.cmp(score, self.scores[i]):
            self.scores[i] = score
            self.patience[i] = 0
        else:
            self.patience[i] += 1
            logging.info('patience_%s %d' % (self.names[i], self.patience[i]))
            if epoch < self.decay_start_epoch:
                continue
            if self.patience[i] >= self.max_patience:
                self.count[i] += 1
                self.patience[i] = 0
                self.scores[i] = score
                # decay may be a scalar shared by all weights or a per-index list.
                decay = self.decay if not isinstance(self.decay, (list, tuple)) else self.decay[i]
                weights_[i] *= decay
                # FIX: was `if not self.min_weight:` — inverted guard; clamp
                # only applies when a floor is actually configured.
                if self.min_weight:
                    if weights_[i] < self.min_weight:
                        weights_[i] = self.min_weight
                #logging.info('!%s decay count:%d decay ratio:%f lr ratios now:%f' % (self.names[i], self.count[i], self.decay, weights[i]))
    # Write the updated weights back once, after all indices are processed.
    if not tf.executing_eagerly():
        self.sess.run(tf.assign(self.weights_op, tf.constant(weights_, dtype=tf.float32)))
    else:
        self.weights_op.assign(weights_)
    return weights_
def print_img(img, i):
    """Resolve the image URL and log the formatted HTML row with training stats."""
    url = get_img_url(img)
    row = img_html.format(url, i, img, melt.epoch(), melt.step(),
                          melt.train_loss(), melt.eval_loss(),
                          melt.duration(), gezi.now_time())
    logging.info(row)
def rate(self, step=None):
    """Implement `lrate` above: linear warmup, then polynomial decay to min_lr."""
    if step is None:
        step = self._step
    # Warmup learning rate scales linearly with progress through the warmup span.
    warmup_frac = step / self.warmup
    warmup_lr = self.lr * warmup_frac
    # Shrink the decay floor in late epochs.
    # NOTE(review): threshold 9 vs denominator (num_epochs - 5) mismatch is
    # reproduced from the original — confirm it is intentional.
    if melt.epoch() >= 9 and FLAGS.num_epochs > 9:
        self.min_lr = self.ori_min_lr * (
            (FLAGS.num_epochs - melt.epoch()) / (FLAGS.num_epochs - 5))
    in_warmup = step < self.warmup
    decayed_lr = lr_poly(self.lr, step, self.num_train_steps, self.min_lr, 1.)
    # bool arithmetic selects exactly one of the two branches.
    return (1.0 - in_warmup) * decayed_lr + in_warmup * warmup_lr
def unk_aug(self, x, x_mask=None):
    """Randomly replace a fraction of tokens with the UNK id (train-time augmentation).

    The replacement ratio is drawn uniformly from [0, FLAGS.unk_aug_max_ratio)
    on each call; padding positions (id 0) are preserved via x_mask.
    TODO: this works, but should it be removed and moved into Dataset so both
    the pytorch and tf paths can share it?
    """
    # Skip entirely outside training, when disabled, or before the start epoch.
    enabled = self.training and FLAGS.unk_aug and melt.epoch() >= FLAGS.unk_aug_start_epoch
    if not enabled:
        return x
    if x_mask is None:
        x_mask = x > 0
    x_mask = x_mask.long()
    ratio = np.random.uniform(0, FLAGS.unk_aug_max_ratio)
    keep = torch.cuda.FloatTensor(x.size(0), x.size(1)).uniform_() > ratio
    keep = keep.long()
    unk_fill = FLAGS.unk_id * (1 - keep)
    return (x * keep + unk_fill) * x_mask
def unk_aug(self, x, x_mask=None):
    """Randomly turn some input tokens into the unk id while training.

    Padding (token id 0) is kept as-is via the mask; the per-call replacement
    probability is sampled from [0, FLAGS.unk_aug_max_ratio).
    """
    do_aug = (self.training and FLAGS.unk_aug
              and melt.epoch() >= FLAGS.unk_aug_start_epoch)
    if not do_aug:
        return x
    if x_mask is None:
        x_mask = x > 0
    x_mask = x_mask.long()
    ratio = np.random.uniform(0, FLAGS.unk_aug_max_ratio)
    keep_mask = torch.cuda.FloatTensor(x.size(0), x.size(1)).uniform_() > ratio
    keep_mask = keep_mask.long()
    replaced = x * keep_mask + FLAGS.unk_id * (1 - keep_mask)
    return replaced * x_mask
def evaluate_scores(predictor, random=False):
    """Run recall-metric evaluation over (a random sample of) the image set.

    Feeds images to `predicts` in batches of FLAGS.metric_eval_batch_size and
    accumulates results into a RecallMetrics instance; logs and returns the
    final (metrics, names) pair.
    """
    timer = gezi.Timer('evaluate_scores')
    init()
    imgs, img_features = get_image_names_and_features()

    num_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
    batch = FLAGS.metric_eval_batch_size

    if random:
        picked = np.random.choice(len(imgs), num_examples, replace=False)
        imgs = imgs[picked]
        img_features = img_features[picked]

    text_max_words = all_distinct_texts.shape[1]
    rank_metrics = gezi.rank_metrics.RecallMetrics()
    print('text_max_words:', text_max_words)

    # Batched prediction over the selected examples.
    for begin in range(0, num_examples, batch):
        finish = min(begin + batch, num_examples)
        print('predicts start:', begin, 'end:', finish, file=sys.stderr)
        predicts(imgs[begin:finish], img_features[begin:finish], predictor,
                 rank_metrics)

    melt.logging_results(
        rank_metrics.get_metrics(),
        rank_metrics.get_names(),
        tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
            melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))
    timer.print()
    return rank_metrics.get_metrics(), rank_metrics.get_names()
def train(Dataset,
          model,
          loss_fn,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          dataset=None,
          valid_dataset=None,
          test_dataset=None,
          sep=','):
  """Main eager-mode training loop shared by the tf and pytorch (FLAGS.torch) paths.

  Builds batched train/valid/test datasets from FLAGS inputs, restores the
  latest checkpoint (tf Checkpoint or torch latest.pyt), then runs the epoch
  loop with periodic console logging, summary writing, metric evaluation,
  checkpoint saving and end-of-training inference.

  Args:
    Dataset: dataset factory class; may be None only when `dataset` is given.
    model: model to train (tf.keras model or torch.nn.Module).
    loss_fn: callable(model, x, y[, training]) -> loss.
    evaluate_fn/eval_fn: optional validation hooks; inference_fn: test hook.
    optimizer/param_groups: optional pre-built optimizer (torch path may build
      Noam/Bert schedules from FLAGS.optimizer otherwise).
    dataset/valid_dataset/test_dataset: optional pre-built Dataset objects.

  Fix(review): the no-valid-input branch assigned `valid_datsset = None`
  (typo), so `valid_dataset` silently kept its argument value; it now clears
  `valid_dataset` as intended. Also `num_valid_steps_per_epoch` is initialized
  up front so it is always bound.
  """
  if Dataset is None:
    assert dataset
  if FLAGS.torch:
    # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
    if torch.cuda.device_count() > 1:
      model = torch.nn.DataParallel(model)
    model.to(device)

  input_ = FLAGS.train_input
  inputs = gezi.list_files(input_)
  inputs.sort()
  all_inputs = inputs

  batch_size = melt.batch_size()
  num_gpus = melt.num_gpus()
  batch_size_ = batch_size

  # FLAGS.fold: hold that shard out of training (it becomes validation below).
  if FLAGS.fold is not None:
    inputs = [
        x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)
        and not x.endswith('%d.tfrecord' % FLAGS.fold)
    ]

  logging.info('inputs', len(inputs), inputs[:100])
  num_folds = FLAGS.num_folds or len(inputs) + 1

  train_dataset_ = dataset or Dataset('train')
  train_dataset = train_dataset_.make_batch(batch_size, inputs)
  num_examples = train_dataset_.num_examples_per_epoch('train')
  num_all_examples = num_examples

  valid_inputs = None
  if FLAGS.valid_input:
    valid_inputs = gezi.list_files(FLAGS.valid_input)
  else:
    if FLAGS.fold is not None:
      # The held-out shard(s); optionally the augmented variants only.
      if not FLAGS.test_aug:
        valid_inputs = [
            x for x in all_inputs if not 'aug' in x and x not in inputs
        ]
      else:
        valid_inputs = [
            x for x in all_inputs if 'aug' in x and x not in inputs
        ]

  logging.info('valid_inputs', valid_inputs)

  if valid_inputs:
    valid_dataset_ = valid_dataset or Dataset('valid')
    valid_dataset = valid_dataset_.make_batch(batch_size_, valid_inputs)
    # Repeating variant used for per-interval loss probes inside the loop.
    valid_dataset2 = valid_dataset_.make_batch(
        batch_size_, valid_inputs, repeat=True)
  else:
    valid_dataset = None  # FIX: was `valid_datsset = None` (typo)
    valid_dataset2 = None

  if num_examples:
    if FLAGS.fold is not None:
      num_examples = int(num_examples * (num_folds - 1) / num_folds)
    num_steps_per_epoch = -(-num_examples // batch_size)  # ceil division
  else:
    num_steps_per_epoch = None
  logging.info('num_train_examples:', num_examples)

  num_valid_examples = None
  num_valid_steps_per_epoch = None  # FIX: always bound, even w/o valid input
  if FLAGS.valid_input:
    num_valid_examples = valid_dataset_.num_examples_per_epoch('valid')
    num_valid_steps_per_epoch = -(
        -num_valid_examples // batch_size_) if num_valid_examples else None
  else:
    if FLAGS.fold is not None:
      if num_examples:
        num_valid_examples = int(num_all_examples * (1 / num_folds))
        num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_)
      else:
        num_valid_steps_per_epoch = None
  logging.info('num_valid_examples:', num_valid_examples)

  if FLAGS.test_input:
    test_inputs = gezi.list_files(FLAGS.test_input)
    logging.info('test_inputs', test_inputs)
  else:
    test_inputs = None

  num_test_examples = None
  if test_inputs:
    test_dataset_ = test_dataset or Dataset('test')
    test_dataset = test_dataset_.make_batch(batch_size_, test_inputs)
    num_test_examples = test_dataset_.num_examples_per_epoch('test')
    num_test_steps_per_epoch = -(
        -num_test_examples // batch_size_) if num_test_examples else None
  else:
    test_dataset = None
  logging.info('num_test_examples:', num_test_examples)

  summary = tf.contrib.summary
  # All writers point at log_dir; separate objects for train/valid streams.
  writer = summary.create_file_writer(FLAGS.log_dir)
  writer_train = summary.create_file_writer(FLAGS.log_dir)
  writer_valid = summary.create_file_writer(FLAGS.log_dir)
  global_step = tf.train.get_or_create_global_step()
  learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")
  tf.add_to_collection('learning_rate', learning_rate)

  learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
  try:
    learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
  except Exception:
    learning_rate_weights = None

  # ckpt dir save models one per epoch
  ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
  os.system('mkdir -p %s' % ckpt_dir)
  # HACK ckpt2 dir actually saves mini epochs (save_interval_epochs < 1),
  # useful when training on a large dataset.
  ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
  os.system('mkdir -p %s' % ckpt_dir2)

  #TODO FIXME now I just changed tf code so to not by default save only latest 5
  # refer to https://github.com/tensorflow/tensorflow/issues/22036
  latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
  if latest_checkpoint:
    logging.info('Latest checkpoint:', latest_checkpoint)
  else:
    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
    logging.info('Latest checkpoint:', latest_checkpoint)
  if os.path.exists(FLAGS.model_dir + '.index'):
    latest_checkpoint = FLAGS.model_dir

  if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
    # In eval-only modes, model_dir itself names the checkpoint to load.
    latest_checkpoint = FLAGS.model_dir

  checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
  checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

  if not FLAGS.torch:
    try:
      optimizer = optimizer or melt.get_optimizer(
          FLAGS.optimizer)(learning_rate)
    except Exception:
      logging.warning(f'Fail to using {FLAGS.optimizer} use adam instead')
      optimizer = melt.get_optimizer('adam')(learning_rate)

    # Full checkpoint (model + optimizer + lr state) for the tf path.
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          model=model,
          optimizer=optimizer,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          learning_rate_weights=learning_rate_weights,
          model=model,
          optimizer=optimizer,
          global_step=global_step)

    checkpoint.restore(latest_checkpoint)
    checkpoint2 = copy.deepcopy(checkpoint)

    start_epoch = int(
        latest_checkpoint.split('-')
        [-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
  else:
    # TODO torch with learning rate adjust
    is_dynamic_opt = False
    if optimizer is None:
      import lele
      is_dynamic_opt = True
      if FLAGS.optimizer == 'noam':
        optimizer = lele.training.optimizers.NoamOpt(
            128, 2, 4000, torch.optim.Adamax(model.parameters(), lr=0))
      elif FLAGS.optimizer == 'bert':
        num_train_steps = int(
            num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs))
        num_warmup_steps = FLAGS.warmup_steps or int(
            num_train_steps * FLAGS.warmup_proportion)
        logging.info('num_train_steps', num_train_steps, 'num_warmup_steps',
                     num_warmup_steps, 'warmup_proportion',
                     FLAGS.warmup_proportion)
        optimizer = lele.training.optimizers.BertOpt(
            FLAGS.learning_rate, FLAGS.min_learning_rate, num_train_steps,
            num_warmup_steps, torch.optim.Adamax(model.parameters(), lr=0))
      else:
        is_dynamic_opt = False
        optimizer = torch.optim.Adamax(
            param_groups if param_groups else model.parameters(),
            lr=FLAGS.learning_rate)

    start_epoch = 0
    latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(
        FLAGS.model_dir, 'latest.pyt')
    if not os.path.exists(latest_path):
      latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
    if os.path.exists(latest_path):
      logging.info('loading torch model from', latest_path)
      checkpoint = torch.load(latest_path)
      if not FLAGS.torch_finetune:
        start_epoch = checkpoint['epoch']
        step = checkpoint['step']
        global_step.assign(step + 1)
      load_torch_model(model, latest_path)
      if FLAGS.torch_load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    # TODO by this way restart can not change learning rate..
    # lr-only tf checkpoint so the torch path still persists lr state.
    if learning_rate_weights is None:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          global_step=global_step)
    else:
      checkpoint = tf.train.Checkpoint(
          learning_rate=learning_rate,
          learning_rate_weight=learning_rate_weight,
          learning_rate_weights=learning_rate_weights,
          global_step=global_step)

    try:
      checkpoint.restore(latest_checkpoint)
      checkpoint2 = copy.deepcopy(checkpoint)
    except Exception:
      pass

  if FLAGS.torch and is_dynamic_opt:
    optimizer._step = global_step.numpy()

  logging.info('optimizer:', optimizer)

  # Seed the tf learning_rate variable from the torch optimizer state.
  if FLAGS.torch_lr:
    learning_rate.assign(optimizer.rate(1))
  if FLAGS.torch:
    learning_rate.assign(optimizer.param_groups[0]['lr'])
    logging.info('learning rate got from pytorch latest.py as', learning_rate)

  learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
  if learning_rate_weights is not None:
    learning_rate_weights.assign(learning_rate_weights *
                                 FLAGS.learning_rate_start_factor)

  # TODO currently not support 0.1 epoch.. like this
  num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

  will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ
  if start_epoch == 0 and not 'EVFIRST' in os.environ and will_valid:
    will_valid = False
  if start_epoch > 0 and will_valid:
    will_valid = True

  if will_valid:
    logging.info('----------valid')
    if FLAGS.torch:
      model.eval()
    names = None
    if evaluate_fn is not None:
      vals, names = evaluate_fn(model, valid_dataset,
                                tf.train.latest_checkpoint(ckpt_dir),
                                num_valid_steps_per_epoch)
    elif eval_fn:
      model_path = None if not write_valid else latest_checkpoint
      names = valid_names if valid_names is not None else [
          infer_names[0]
      ] + [x + '_y' for x in infer_names[1:]
          ] + infer_names[1:] if infer_names else None
      logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir)
      vals, names = evaluate(model,
                             valid_dataset,
                             eval_fn,
                             model_path,
                             names,
                             valid_write_fn,
                             write_streaming,
                             num_valid_steps_per_epoch,
                             suffix=valid_suffix,
                             sep=sep)
    if names:
      logging.info2(
          'epoch:%d/%d' % (start_epoch, num_epochs),
          ['%s:%.5f' % (name, val) for name, val in zip(names, vals)])

  if FLAGS.work_mode == 'valid':
    exit(0)

  if 'test' in FLAGS.work_mode:
    logging.info('--------test/inference')
    if test_dataset:
      if FLAGS.torch:
        model.eval()
      if inference_fn is None:
        assert latest_checkpoint
        inference(model,
                  test_dataset,
                  latest_checkpoint,
                  infer_names,
                  infer_debug_names,
                  infer_write_fn,
                  write_streaming,
                  num_test_steps_per_epoch,
                  suffix=infer_suffix)
      else:
        inference_fn(model, test_dataset,
                     tf.train.latest_checkpoint(ckpt_dir),
                     num_test_steps_per_epoch)
    exit(0)

  if 'SHOW' in os.environ:
    num_epochs = start_epoch + 1

  class PytObj(object):
    """Minimal `.numpy()` adapter so torch means look like tf metric results."""

    def __init__(self, x):
      self.x = x

    def numpy(self):
      return self.x

  class PytMean(object):
    """Streaming mean mimicking tfe.metrics.Mean for the torch path."""

    def __init__(self):
      self._val = 0.
      self.count = 0
      self.is_call = True

    def clear(self):
      self._val = 0
      self.count = 0

    def __call__(self, val):
      # After a result() read, the next update starts a fresh accumulation.
      if not self.is_call:
        self.clear()
        self.is_call = True
      self._val += val.item()
      self.count += 1

    def result(self):
      if self.is_call:
        self.is_call = False
      if not self.count:
        val = 0
      else:
        val = self._val / self.count
      # TODO just for compact with tf ..
      return PytObj(val)

  Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean

  timer = gezi.Timer()
  num_insts = 0

  if FLAGS.learning_rate_decay_factor > 0:
    #NOTICE if you do finetune or other things which might change batch_size then
    #you'd better directly set num_steps_per_decay since global step / decay_steps
    #will not be the correct epoch as num_steps per epoch changed
    assert FLAGS.num_steps_per_decay or (
        FLAGS.num_epochs_per_decay and num_steps_per_epoch
    ), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch'
    decay_steps = FLAGS.num_steps_per_decay or int(
        num_steps_per_epoch * FLAGS.num_epochs_per_decay)
    decay_start_step = FLAGS.decay_start_step or int(
        num_steps_per_epoch * FLAGS.decay_start_epoch)
    logging.info(
        'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
        .format(FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay,
                decay_steps, FLAGS.decay_start_epoch, decay_start_step))

  for epoch in range(start_epoch, num_epochs):
    melt.set_global('epoch', '%.4f' % (epoch))
    if FLAGS.torch:
      model.train()

    epoch_loss_avg = Mean()
    epoch_valid_loss_avg = Mean()

    for i, (x, y) in enumerate(train_dataset):
      if FLAGS.torch:
        x, y = to_torch(x, y)
        if is_dynamic_opt:
          learning_rate.assign(optimizer.rate())

      # One optimization step (tf eager gradients vs torch autograd).
      if not FLAGS.torch:
        loss, grads = melt.eager.grad(model, x, y, loss_fn)
        grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
        optimizer.apply_gradients(zip(grads, model.variables))
      else:
        optimizer.zero_grad()
        if 'training' in inspect.getargspec(loss_fn).args:
          loss = loss_fn(model, x, y, training=True)
        else:
          loss = loss_fn(model, x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       FLAGS.clip_gradients)
        optimizer.step()

      global_step.assign_add(1)
      epoch_loss_avg(loss)  # add current batch loss
      if FLAGS.torch:
        del loss

      batch_size_ = list(
          x.values())[0].shape[FLAGS.batch_size_dim] if type(x) == type(
              {}) else x.shape[FLAGS.batch_size_dim]
      num_insts += int(batch_size_)
      if global_step.numpy() % FLAGS.interval_steps == 0:
        elapsed = timer.elapsed()
        steps_per_second = FLAGS.interval_steps / elapsed
        instances_per_second = num_insts / elapsed
        num_insts = 0

        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
          epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)

        if valid_dataset2:
          # Probe a single valid batch for the running valid loss.
          try:
            x, y = next(iter(valid_dataset2))
          except Exception:
            # TODO FIXME how.. iterate stop restart.., here hack for my
            # iterator see projects/lm/dataset
            x, y = next(iter(valid_dataset2))
          if FLAGS.torch:
            x, y = to_torch(x, y)
            model.eval()
          valid_loss = loss_fn(model, x, y)
          epoch_valid_loss_avg(valid_loss)
          if FLAGS.torch:
            model.train()

          logging.info2(
              'epoch:%.3f/%d' %
              ((epoch + i / num_steps_per_epoch), num_epochs),
              'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' % elapsed,
              'batch_size:[%d]' % batch_size_, 'gpus:[%d]' % num_gpus,
              'batches/s:[%.2f]' % steps_per_second,
              'insts/s:[%d]' % instances_per_second, '%s' % epoch_time_info,
              'lr:[%.8f]' % learning_rate.numpy(),
              'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
              'valid_loss:[%.4f]' % epoch_valid_loss_avg.result().numpy())
          if global_step.numpy() % FLAGS.eval_interval_steps == 0:
            with writer_valid.as_default(), summary.always_record_summaries():
              summary.scalar('loss/eval',
                             epoch_valid_loss_avg.result().numpy())
              writer_valid.flush()
        else:
          logging.info2(
              'epoch:%.3f/%d' %
              ((epoch + i / num_steps_per_epoch), num_epochs),
              'step:%d' % global_step.numpy(), 'elapsed:[%.3f]' % elapsed,
              'batch_size:[%d]' % batch_size_, 'gpus:[%d]' % num_gpus,
              'batches/s:[%.2f]' % steps_per_second,
              'insts/s:[%d]' % instances_per_second, '%s' % epoch_time_info,
              'lr:[%.8f]' % learning_rate.numpy(),
              'train_loss:[%.4f]' % epoch_loss_avg.result().numpy())

        if global_step.numpy() % FLAGS.eval_interval_steps == 0:
          with writer_train.as_default(), summary.always_record_summaries():
            summary.scalar('loss/train_avg', epoch_loss_avg.result().numpy())
            summary.scalar('learning_rate', learning_rate.numpy())
            summary.scalar('batch_size', batch_size_)
            summary.scalar('epoch', melt.epoch())
            summary.scalar('steps_per_second', steps_per_second)
            summary.scalar('instances_per_second', instances_per_second)
            writer_train.flush()
          if FLAGS.log_dir != FLAGS.model_dir:
            assert FLAGS.log_dir
            command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir,
                                                  FLAGS.model_dir)
            print(command, file=sys.stderr)
            os.system(command)

      # Step-interval metric evaluation on the validation set.
      if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy(
      ) and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
        if FLAGS.torch:
          model.eval()
        vals, names = None, None
        if evaluate_fn is not None:
          vals, names = evaluate_fn(model, valid_dataset, None,
                                    num_valid_steps_per_epoch)
        elif eval_fn:
          names = valid_names if valid_names is not None else [
              infer_names[0]
          ] + [x + '_y' for x in infer_names[1:]
              ] + infer_names[1:] if infer_names else None
          vals, names = evaluate(model,
                                 valid_dataset,
                                 eval_fn,
                                 None,
                                 names,
                                 valid_write_fn,
                                 write_streaming,
                                 num_valid_steps_per_epoch,
                                 sep=sep)
        if vals and names:
          with writer_valid.as_default(), summary.always_record_summaries():
            for name, val in zip(names, vals):
              summary.scalar(f'step/valid/{name}', val)
            writer_valid.flush()
        if FLAGS.torch:
          if not FLAGS.torch_lr:
            # control learning rate by tensorflow learning rate
            for param_group in optimizer.param_groups:
              # important learning rate decay
              param_group['lr'] = learning_rate.numpy()
          model.train()
        if names and vals:
          logging.info2(
              'epoch:%.3f/%d' %
              ((epoch + i / num_steps_per_epoch), num_epochs),
              'valid_step:%d' % global_step.numpy(), 'valid_metrics', [
                  '%s:%.5f' % (name, val) for name, val in zip(names, vals)
              ])

      if global_step.numpy() % FLAGS.save_interval_steps == 0:
        if FLAGS.torch:
          state = {
              'epoch':
                  epoch,
              'step':
                  global_step.numpy(),
              'state_dict':
                  model.state_dict() if not hasattr(model, 'module') else
                  model.module.state_dict(),
              'optimizer':
                  optimizer.state_dict(),
          }
          torch.save(state, os.path.join(FLAGS.model_dir, 'latest.pyt'))

      # TODO fixme why if both checpoint2 and chekpoint used... not ok..
      if FLAGS.save_interval_epochs and FLAGS.save_interval_epochs < 1 and global_step.numpy(
      ) % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
        checkpoint2.save(checkpoint_prefix2)
        if FLAGS.torch:
          state = {
              'epoch':
                  epoch,
              'step':
                  global_step.numpy(),
              'state_dict':
                  model.state_dict() if not hasattr(model, 'module') else
                  model.module.state_dict(),
              'optimizer':
                  optimizer.state_dict(),
          }
          torch.save(state, tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

      if FLAGS.learning_rate_decay_factor > 0:
        if global_step.numpy() >= decay_start_step and global_step.numpy(
        ) % decay_steps == 0:
          lr = max(
              learning_rate.numpy() * FLAGS.learning_rate_decay_factor,
              FLAGS.min_learning_rate)
          if lr < learning_rate.numpy():
            learning_rate.assign(lr)
            if FLAGS.torch:
              for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate.numpy()

      if epoch == start_epoch and i == 0:
        try:
          if not FLAGS.torch:
            logging.info(model.summary())
        except Exception:
          traceback.print_exc()
          logging.info(
              'Fail to do model.summary() may be you have layer define in init but not used in call'
          )
        if 'SHOW' in os.environ:
          exit(0)

    # ---- end of one epoch ----
    logging.info2(
        'epoch:%d/%d' % (epoch + 1, num_epochs),
        'step:%d' % global_step.numpy(), 'batch_size:[%d]' % batch_size,
        'gpus:[%d]' % num_gpus, 'lr:[%.8f]' % learning_rate.numpy(),
        'train_loss:[%.4f]' % epoch_loss_avg.result().numpy(),
        'valid_loss::[%.4f]' % epoch_valid_loss_avg.result().numpy())

    timer = gezi.Timer(
        f'save model to {checkpoint_prefix}-{checkpoint.save_counter.numpy() + 1}',
        False)
    checkpoint.save(checkpoint_prefix)
    if FLAGS.torch and FLAGS.save_interval_epochs == 1:
      state = {
          'epoch':
              epoch + 1,
          'step':
              global_step.numpy(),
          'state_dict':
              model.state_dict()
              if not hasattr(model, 'module') else model.module.state_dict(),
          'optimizer':
              optimizer.state_dict(),
      }
      torch.save(state, tf.train.latest_checkpoint(ckpt_dir) + '.pyt')
    timer.print_elapsed()

    if valid_dataset and (epoch + 1) % FLAGS.valid_interval_epochs == 0:
      if FLAGS.torch:
        model.eval()
      vals, names = None, None
      if evaluate_fn is not None:
        vals, names = evaluate_fn(model, valid_dataset,
                                  tf.train.latest_checkpoint(ckpt_dir),
                                  num_valid_steps_per_epoch)
      elif eval_fn:
        model_path = None if not write_valid else tf.train.latest_checkpoint(
            ckpt_dir)
        names = valid_names if valid_names is not None else [
            infer_names[0]
        ] + [x + '_y' for x in infer_names[1:]
            ] + infer_names[1:] if infer_names else None
        vals, names = evaluate(model,
                               valid_dataset,
                               eval_fn,
                               model_path,
                               names,
                               valid_write_fn,
                               write_streaming,
                               num_valid_steps_per_epoch,
                               suffix=valid_suffix,
                               sep=sep)
      if vals and names:
        logging.info2('epoch:%d/%d' % (epoch + 1, num_epochs),
                      'step:%d' % global_step.numpy(), 'epoch_valid_metrics', [
                          '%s:%.5f' % (name, val)
                          for name, val in zip(names, vals)
                      ])

    # Epoch summaries are written at x = epoch index; global_step is
    # temporarily swapped to (epoch + 1) and restored afterwards.
    with writer.as_default(), summary.always_record_summaries():
      temp = global_step.value()
      global_step.assign(epoch + 1)
      summary.scalar('epoch/train/loss', epoch_loss_avg.result().numpy())
      if valid_dataset:
        if FLAGS.torch:
          model.eval()
        if vals and names:
          for name, val in zip(names, vals):
            summary.scalar(f'epoch/valid/{name}', val)
      writer.flush()
      global_step.assign(temp)

    if test_dataset and (epoch + 1) % FLAGS.inference_interval_epochs == 0:
      if FLAGS.torch:
        model.eval()
      if inference_fn is None:
        inference(model,
                  test_dataset,
                  tf.train.latest_checkpoint(ckpt_dir),
                  infer_names,
                  infer_debug_names,
                  infer_write_fn,
                  write_streaming,
                  num_test_steps_per_epoch,
                  suffix=infer_suffix,
                  sep=sep)
      else:
        inference_fn(model, test_dataset,
                     tf.train.latest_checkpoint(ckpt_dir),
                     num_test_steps_per_epoch)

  # Final sync of logs to model_dir; drop intermediate latest.pyt copies.
  if FLAGS.log_dir != FLAGS.model_dir:
    assert FLAGS.log_dir
    command = 'rsync -l -r -t %s/* %s' % (FLAGS.log_dir, FLAGS.model_dir)
    print(command, file=sys.stderr)
    os.system(command)
    command = 'rm -rf %s/latest.pyt.*' % (FLAGS.model_dir)
    print(command, file=sys.stderr)
    os.system(command)
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
  """Run one training step in TF1 graph (session) mode.

  Per call this: runs the train ops, optionally runs `eval_ops` every
  `valid_interval_steps`, optionally runs `metric_eval_fn` every
  `metric_eval_interval_steps` (or when `model_path` is given), logs
  progress/timing every `interval_steps`, writes TensorBoard summaries
  when `log_dir` is set, and applies patience-based learning-rate decay.

  Args:
    sess: tf.Session used for all runs.
    step: current global step (int).
    ops: list of ops to run each step; ops[0] is assumed to be the train
      op and ops[1] the learning rate (see unpacking below).
    names / eval_names: metric names for train/eval results; prefixed
      with 'train/' and 'eval/' for display.
    metric_eval_fn: callable returning (results, names) or
      (results, names, summaries) or a dict of name -> value.
    model_path: when set, metric eval is forced and summaries use the
      per-epoch 'eval' prefix instead of 'step_eval'.

  Returns:
    `stop` (bool) when this was an eval step with `log_dir` set,
    otherwise falls through and returns None.
  """
  # NOTE(review): the parameter is overridden here — horovod use is
  # detected from the environment, the argument is effectively ignored.
  use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

  #is_start = False # force not to evaluate at first step
  #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
  timer = gezi.Timer()
  if print_time:
    # Function attributes act as cross-call state (first call initializes).
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()

  melt.set_global('step', step)
  # fixed_step lets epoch stay correct after a batch-size change.
  epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
  if not num_epochs:
    epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
  else:
    epoch_str = 'epoch:%.3f/%d' % (
        epoch, num_epochs) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.2f' % (epoch))

  info = IO()
  stop = False

  if eval_names is None:
    if names:
      eval_names = ['eval/' + x for x in names]

  if names:
    names = ['train/' + x for x in names]

  if eval_names:
    eval_names = ['eval/' + x for x in eval_names]

  # `and` binds tighter than `or`: eval on start OR on the interval.
  is_eval_step = is_start or valid_interval_steps and step % valid_interval_steps == 0
  summary_str = []

  eval_str = ''
  if is_eval_step:
    # deal with summary
    if log_dir:
      # Lazily build the (possibly filtered) merged summary op and writer once.
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            for summary_excl in summary_excls:
              if not summary_excl in op.name:
                summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)

        #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)

    # if eval ops then should have bee rank 0
    if eval_ops:
      #if deal_eval_results_fn is None and eval_names is not None:
      #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
      for i in range(eval_loops):
        eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
        #eval_feed_dict.update(feed_dict)

        # if use horovod let each rank use same sess.run!
        if not log_dir or train_once.summary_op is None or gezi.env_has(
            'EVAL_NO_SUMMARY') or use_horovod:
          #if not log_dir or train_once.summary_op is None:
          eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
        else:
          eval_results = sess.run(eval_ops + [train_once.summary_op],
                                  feed_dict=eval_feed_dict)
          summary_str = eval_results[-1]
          eval_results = eval_results[:-1]
        eval_loss = gezi.get_singles(eval_results)
        #timer_.print()
        eval_stop = False
        if use_horovod:
          # Barrier so all ranks stay in lock-step around evaluation.
          sess.run(hvd.allreduce(tf.constant(0)))

        #if not use_horovod or hvd.local_rank() == 0:
        # @TODO user print should also use logging as a must ?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        #if not use_horovod or hvd.rank() == 0:
        #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
        eval_str = 'valid:{}'.format(
            melt.parse_results(eval_loss, eval_names_))

        # if deal_eval_results_fn is not None:
        #   eval_stop = deal_eval_results_fn(eval_results)

        assert len(eval_loss) > 0
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        if not use_horovod or hvd.rank() == 0:
          melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
  elif interval_steps != valid_interval_steps:
    #print()
    pass

  metric_evaluate = False

  # if metric_eval_fn is not None \
  #   and (is_start \
  #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #     or (metric_eval_interval_steps \
  #          and step % metric_eval_interval_steps == 0)):
  #   metric_evaluate = True

  if metric_eval_fn is not None \
    and ((is_start or metric_eval_interval_steps \
      and step % metric_eval_interval_steps == 0) or model_path):
    metric_evaluate = True

  # EVFIRST env var forces metric eval on/off for the very first step.
  if 'EVFIRST' in os.environ:
    if os.environ['EVFIRST'] == '0':
      if is_start:
        metric_evaluate = False
    else:
      if is_start:
        metric_evaluate = True

  if step == 0 or 'QUICK' in os.environ:
    metric_evaluate = False

  #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
  if metric_evaluate:
    if use_horovod:
      print('------------metric evaluate step', step, model_path, hvd.rank())
    # NOTE(review): inspect.getargspec is removed in Python 3.11 —
    # getfullargspec is the replacement; left as-is to avoid changing
    # behavior on the interpreter this project targets.
    if not model_path or 'model_path' not in inspect.getargspec(
        metric_eval_fn).args:
      metric_eval_fn_ = metric_eval_fn
    else:
      metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

    try:
      l = metric_eval_fn_()
      if isinstance(l, tuple):
        num_returns = len(l)
        if num_returns == 2:
          evaluate_results, evaluate_names = l
          evaluate_summaries = None
        else:
          assert num_returns == 3, 'retrun 1,2,3 ok 4.. not ok'
          evaluate_results, evaluate_names, evaluate_summaries = l
      else:  #return dict of name -> value
        # BUGFIX: was `tuple(zip(*dict.items()))` — calling .items() on the
        # builtin dict class (TypeError) and handing keys to results.
        # zip(*l.items()) yields (keys, values) == (names, results).
        evaluate_names, evaluate_results = tuple(zip(*l.items()))
        evaluate_summaries = None
    except Exception:
      # NOTE(review): if metric_eval_fn_ raises, evaluate_results /
      # evaluate_names stay unbound and the logging below will NameError —
      # pre-existing hazard, left unchanged.
      logging.info('Do nothing for metric eval fn with exception:\n', traceback.format_exc())

    if not use_horovod or hvd.rank() == 0:
      #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
      logging.info2('{} valid_step:{} {}:{}'.format(
          epoch_str, step, 'valid_metrics',
          melt.parse_results(evaluate_results, evaluate_names)))

    # Patience-based learning-rate decay driven by the first metric value.
    if learning_rate is not None and (learning_rate_patience and learning_rate_patience > 0):
      assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
      valid_loss = evaluate_results[0]
      if not hasattr(train_once, 'min_valid_loss'):
        train_once.min_valid_loss = valid_loss
        train_once.deacy_steps = []
        train_once.patience = 0
      else:
        if valid_loss < train_once.min_valid_loss:
          train_once.min_valid_loss = valid_loss
          train_once.patience = 0
        else:
          train_once.patience += 1
          logging.info2('{} valid_step:{} patience:{}'.format(
              epoch_str, step, train_once.patience))

      if learning_rate_patience and train_once.patience >= learning_rate_patience:
        # ops[1] is assumed to hold the learning-rate variable.
        lr_op = ops[1]
        lr = sess.run(lr_op) * learning_rate_decay_factor
        train_once.deacy_steps.append(step)
        logging.info2(
            '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
            .format(epoch_str, step, learning_rate_decay_factor,
                    ','.join(map(str, train_once.deacy_steps))))
        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
        train_once.patience = 0
        train_once.min_valid_loss = valid_loss

  if ops is not None:
    #if deal_results_fn is None and names is not None:
    #  deal_results_fn = lambda x: melt.print_results(x, names)
    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
    # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
    #print('---------------ops', ops)
    if eval_ops is not None or not log_dir or not hasattr(
        train_once, 'summary_op') or train_once.summary_op is None or use_horovod:
      feed_dict[K.learning_phase()] = 1
      results = sess.run(ops, feed_dict=feed_dict)
    else:
      ## TODO why below ?
      #try:
      feed_dict[K.learning_phase()] = 1
      results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
      summary_str = results[-1]
      results = results[:-1]
      # except Exception:
      #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
      #   results = sess.run(ops, feed_dict=feed_dict)

    #print('------------results', results)

    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #         ops,
    #         feed_dict=feed_dict,
    #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #         run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    #reults[0] assume to be train_op, results[1] to be learning_rate
    learning_rate = results[1]
    results = results[2:]

    #@TODO should support aver loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)

    steps_per_second = None
    instances_per_second = None
    hours_per_epoch = None
    #step += 1
    #if is_start or interval_steps and step % interval_steps == 0:
    interval_ok = not use_horovod or hvd.local_rank() == 0
    if interval_steps and step % interval_steps == 0 and interval_ok:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.2f} '.format(duration)
        melt.set_global('duration', '%.2f' % duration)
        #info.write(duration_str)
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size / elapsed
        gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
          epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch)
        info.write(
            'elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'
            .format(elapsed, batch_size, gpu_info, steps_per_second,
                    instances_per_second, epoch_time_info, learning_rate))

      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
        info.write(' train:{} '.format(
            melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))

      info.write(eval_str)
      #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step, info.getvalue()))

      if deal_results_fn is not None:
        stop = deal_results_fn(results)

  # ---------------- summary writing ----------------
  summary_strs = gezi.to_list(summary_str)
  if metric_evaluate:
    if evaluate_summaries is not None:
      summary_strs += evaluate_summaries

  if step > 1:
    if is_eval_step:
      # deal with summary
      if log_dir:
        summary = tf.Summary()
        if eval_ops is None:
          if train_once.summary_op is not None:
            for summary_str in summary_strs:
              train_once.summary_writer.add_summary(summary_str, step)
        else:
          for summary_str in summary_strs:
            train_once.summary_writer.add_summary(summary_str, step)
          suffix = 'valid' if not eval_names else ''
          # loss/valid
          melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)

        if ops is not None:
          try:
            # loss/train_avg
            melt.add_summarys(summary, train_average_loss, names_, suffix='train_avg')
          except Exception:
            pass
          ##optimizer has done this also
          melt.add_summary(summary, learning_rate, 'learning_rate')
          melt.add_summary(summary, melt.batch_size(), 'batch_size', prefix='other')
          melt.add_summary(summary, melt.epoch(), 'epoch', prefix='other')
          if steps_per_second:
            melt.add_summary(summary, steps_per_second, 'steps_per_second', prefix='perf')
          if instances_per_second:
            melt.add_summary(summary, instances_per_second, 'instances_per_second', prefix='perf')
          if hours_per_epoch:
            melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch', prefix='perf')

        if metric_evaluate:
          #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
          prefix = 'step_eval'
          if model_path:
            prefix = 'eval'
            # Epoch-level eval steps get their own monotonically-increasing
            # counter so TensorBoard plots one point per epoch eval.
            if not hasattr(train_once, 'epoch_step'):
              train_once.epoch_step = 1
            else:
              train_once.epoch_step += 1
            step = train_once.epoch_step
          # eval/loss eval/auc ..
          melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)

        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
      return stop
    elif metric_evaluate and log_dir:
      summary = tf.Summary()
      for summary_str in summary_strs:
        train_once.summary_writer.add_summary(summary_str, step)
      #summary.ParseFromString(evaluate_summaries)
      summary_writer = train_once.summary_writer
      prefix = 'step_eval'
      if model_path:
        prefix = 'eval'
        if not hasattr(train_once, 'epoch_step'):
          ## TODO.. restart will get 1 again..
          #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
          #epoch_step += 1
          #train_once.epoch_step = sess.run(epoch_step)
          valid_interval_epochs = 1.
          try:
            valid_interval_epochs = FLAGS.valid_interval_epochs
          except Exception:
            pass
          train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
              int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
          logging.info('train_once epoch start step is', train_once.epoch_step)
        else:
          #epoch_step += 1
          train_once.epoch_step += 1
        step = train_once.epoch_step
      #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
      melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
def get_num_finetune_words():
  """Return how many embedding words to finetune.

  With dynamic finetuning enabled, the budget grows with training
  progress (1000 words per epoch) up to the configured cap; otherwise
  the configured fixed count is used directly.
  """
  if FLAGS.dynamic_finetune:
    # Grow with epoch, never exceeding the configured maximum.
    return min(int(melt.epoch() * 1000), FLAGS.num_finetune_words)
  return FLAGS.num_finetune_words
def evaluate_scores(predictor, random=False):
  """Evaluate retrieval recall metrics for img->text and/or text->img.

  Runs the predictor over the evaluation set in batches of
  FLAGS.metric_eval_batch_size, accumulating recall metrics, and logs
  them. Which directions run is controlled by FLAGS.eval_img2text and
  FLAGS.eval_text2img.

  Args:
    predictor: model/predictor object passed through to predicts() /
      predicts_txt2im().
    random: if True, evaluate on a random subsample of
      FLAGS.num_metric_eval_examples examples instead of a prefix.

  Returns:
    (metrics, names) — concatenated from both directions when both are
    enabled; text2img names get a 't2i' prefix.
  """
  timer = gezi.Timer('evaluate_scores')
  init()

  # BUGFIX: step was previously assigned only inside the img2text branch,
  # so running with eval_img2text=False raised NameError in the text2img loop.
  step = FLAGS.metric_eval_batch_size

  if FLAGS.eval_img2text:
    imgs, img_features = get_image_names_and_features()

    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))

    if random:
      index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
      imgs = imgs[index]
      img_features = img_features[index]

    rank_metrics = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = start + step
      if end > num_metric_eval_examples:
        end = num_metric_eval_examples
      print('predicts image start:', start, 'end:', end, file=sys.stderr)
      predicts(imgs[start:end], img_features[start:end], predictor, rank_metrics)
      start = end

    melt.logging_results(
        rank_metrics.get_metrics(),
        rank_metrics.get_names(),
        tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
            melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))

  if FLAGS.eval_text2img:
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples,
                                   len(all_distinct_texts))

    # BUGFIX: default to the full arrays — previously text_strs/texts were
    # bound only when random=True, so random=False raised NameError here.
    text_strs = all_distinct_text_strs
    texts = all_distinct_texts
    if random:
      index = np.random.choice(len(all_distinct_texts), num_metric_eval_examples,
                               replace=False)
      text_strs = all_distinct_text_strs[index]
      texts = all_distinct_texts[index]

    rank_metrics2 = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = start + step
      if end > num_metric_eval_examples:
        end = num_metric_eval_examples
      print('predicts start:', start, 'end:', end, file=sys.stderr)
      predicts_txt2im(text_strs[start:end], texts[start:end], predictor, rank_metrics2)
      start = end

    melt.logging_results(rank_metrics2.get_metrics(),
                         ['t2i' + x for x in rank_metrics2.get_names()],
                         tag='text2img')

  timer.print()

  if FLAGS.eval_img2text and FLAGS.eval_text2img:
    return rank_metrics.get_metrics() + rank_metrics2.get_metrics(
    ), rank_metrics.get_names() + [
        't2i' + x for x in rank_metrics2.get_names()
    ]
  elif FLAGS.eval_img2text:
    return rank_metrics.get_metrics(), rank_metrics.get_names()
  else:
    return rank_metrics2.get_metrics(), rank_metrics2.get_names()
def train(model, loss_fn, Dataset=None, dataset=None, valid_dataset=None, valid_dataset2=None, test_dataset=None, evaluate_fn=None, inference_fn=None, eval_fn=None, write_valid=True, valid_names=None, infer_names=None, infer_debug_names=None, valid_write_fn=None, infer_write_fn=None, valid_suffix='.valid', infer_suffix='.infer', write_streaming=False, optimizer=None, param_groups=None, init_fn=None, sep=','): use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ if Dataset is None: assert dataset logging.info('Dataset', Dataset, 'dataset', dataset, 'valid_dataset', valid_dataset, 'test_dataset', test_dataset, loss_fn) if FLAGS.torch: torch.manual_seed(FLAGS.seed or 0) if torch.cuda.device_count(): torch.cuda.manual_seed(FLAGS.seed or 0) if use_horovod: pass # import horovod.torch as hvd # hvd.init() # #print('-----------------', hvd, hvd.size()) # assert hvd.mpi_threads_supported() # assert hvd.size() == comm.Get_size() # torch.cuda.set_device(hvd.local_rank()) # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html else: if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) model.to(device) input_ = FLAGS.train_input inputs = gezi.list_files(input_) inputs.sort() all_inputs = inputs #batch_size = FLAGS.batch_size batch_size = melt.batch_size() num_gpus = melt.num_gpus() #batch_size = max(batch_size, 1) #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1]) batch_size_ = FLAGS.eval_batch_size or batch_size if dataset is None: if FLAGS.fold is not None: inputs = [x for x in inputs if not x.endswith('%d.record' % FLAGS.fold) and not x.endswith('%d.tfrecord' % FLAGS.fold)] # if FLAGS.valid_input: # inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)] logging.info('inputs', len(inputs), inputs[:100]) num_folds = FLAGS.num_folds or len(inputs) + 1 if dataset is None: dataset = Dataset('train') assert len(inputs) > 0 train_dataset = 
dataset.make_batch(batch_size, inputs, simple_parse=FLAGS.simple_parse) num_examples = dataset.num_examples_per_epoch('train') else: assert FLAGS.torch_only, 'only torch only currently support input dataset not Dataset class type, because we do not have len function there' train_dataset = dataset num_examples = len(train_dataset.dataset) num_all_examples = num_examples if valid_dataset is None: valid_inputs = None if FLAGS.valid_input: valid_inputs = gezi.list_files(FLAGS.valid_input) else: if FLAGS.fold is not None: #valid_inputs = [x for x in all_inputs if x not in inputs] if not FLAGS.test_aug: valid_inputs = [x for x in all_inputs if not 'aug' in x and x not in inputs] else: valid_inputs = [x for x in all_inputs if 'aug' in x and x not in inputs] logging.info('valid_inputs', valid_inputs) num_valid_examples = None if valid_dataset is not None: num_valid_examples = len(valid_dataset.dataset) num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None valid_dataset2_iter = iter(valid_dataset2) else: if valid_inputs: valid_dataset = dataset.make_batch(batch_size_, valid_inputs, subset='valid', hvd_shard=FLAGS.horovod_eval ) valid_dataset2 = dataset.make_batch(batch_size, valid_inputs, subset='valid', repeat=True, initializable=False, hvd_shard=False) valid_dataset2_iter = iter(valid_dataset2) else: valid_datsset = None valid_dataset2 = None if num_examples: if FLAGS.fold is not None: num_examples = int(num_examples * (num_folds - 1) / num_folds) num_steps_per_epoch = -(-num_examples // batch_size) else: num_steps_per_epoch = None logging.info('num_train_examples:', num_examples) if use_horovod and num_examples: num_steps_per_epoch = -(-num_examples // (batch_size * hvd.size())) if num_valid_examples is None: if FLAGS.valid_input: num_valid_examples = dataset.num_examples_per_epoch('valid') num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) if num_valid_examples else None else: if FLAGS.fold is not None: if 
num_examples: num_valid_examples = int(num_all_examples * (1 / num_folds)) num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_) else: num_valid_steps_per_epoch = None if use_horovod and FLAGS.horovod_eval and num_valid_examples: num_valid_steps_per_epoch = -(-num_valid_examples // (batch_size_ * hvd.size())) logging.info('num_valid_examples:', num_valid_examples) if test_dataset is None: if FLAGS.test_input: test_inputs = gezi.list_files(FLAGS.test_input) #test_inputs = [x for x in test_inputs if not 'aug' in x] logging.info('test_inputs', test_inputs) else: test_inputs = None num_test_examples = None if test_dataset is not None: num_test_examples = len(test_dataset.dataset) else: if test_inputs: test_dataset = dataset.make_batch(batch_size_, test_inputs, subset='test') num_test_examples = dataset.num_examples_per_epoch('test') else: test_dataset = None num_test_steps_per_epoch = -(-num_test_examples // batch_size_) if num_test_examples else None if use_horovod and FLAGS.horovod_eval and num_test_examples: num_test_steps_per_epoch = -(-num_test_examples // (batch_size_ * hvd.size())) logging.info('num_test_examples:', num_test_examples) summary = tf.contrib.summary # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch') # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train') # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid') writer = summary.create_file_writer(FLAGS.log_dir) writer_train = summary.create_file_writer(FLAGS.log_dir) writer_valid = summary.create_file_writer(FLAGS.log_dir) global_step = tf.train.get_or_create_global_step() ## RuntimeError: tf.summary.FileWriter is not compatible with eager execution. Use tf.contrib.summary instead. 
#logger = gezi.SummaryWriter(FLAGS.log_dir) learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate") tf.add_to_collection('learning_rate', learning_rate) learning_rate_weight = tf.get_collection('learning_rate_weight')[-1] try: learning_rate_weights = tf.get_collection('learning_rate_weights')[-1] except Exception: learning_rate_weights = None # ckpt dir save models one per epoch ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt') os.system('mkdir -p %s' % ckpt_dir) # HACK ckpt dir is actually save mini epoch like when you set save_interval_epochs=0.1, this is usefull when you training large dataset ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2') os.system('mkdir -p %s' % ckpt_dir2) #TODO FIXME now I just changed tf code so to not by default save only latest 5 # refer to https://github.com/tensorflow/tensorflow/issues/22036 # manager = tf.contrib.checkpoint.CheckpointManager( # checkpoint, directory=ckpt_dir, max_to_keep=5) # latest_checkpoint = manager.latest_checkpoint latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir) if latest_checkpoint: logging.info('Latest checkpoint:', latest_checkpoint) else: latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2) logging.info('Latest checkpoint:', latest_checkpoint) if os.path.exists(FLAGS.model_dir + '.index'): latest_checkpoint = FLAGS.model_dir if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode: #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir latest_checkpoint = FLAGS.model_dir #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint) checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt') checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt') if not FLAGS.torch: try: optimizer = optimizer or melt.get_optimizer(FLAGS.optimizer)(learning_rate) except Exception: logging.warning(f'Fail to using {FLAGS.optimizer} use adam instead') optimizer = melt.get_optimizer('adam')(learning_rate) # TODO... 
if learning_rate_weights is None: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, model=model, optimizer=optimizer, global_step=global_step) else: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, learning_rate_weights=learning_rate_weights, model=model, optimizer=optimizer, global_step=global_step) checkpoint.restore(latest_checkpoint) checkpoint2 = copy.deepcopy(checkpoint) start_epoch = int(latest_checkpoint.split('-')[-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0 start_step = 0 # TODO else: # TODO torch with learning rate adjust # https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py # TODO full support for pytorch now not work if optimizer is None: import lele is_dynamic_opt = True if FLAGS.optimizer == 'noam': optimizer_ = torch.optim.Adamax(model.parameters(), lr=0) if use_horovod: optimizer_ = hvd.DistributedOptimizer(optimizer_) optimizer = lele.training.optimizers.NoamOpt(128, 2, 4000, optimzier_) elif FLAGS.optimizer == 'bert': num_train_steps = int(num_steps_per_epoch * (FLAGS.num_decay_epochs or FLAGS.num_epochs)) if FLAGS.warmup_steps and use_horovod: FLAGS.warmup_steps = max(int(FLAGS.warmup_steps / hvd.size()), 1) num_warmup_steps = FLAGS.warmup_steps or int(num_steps_per_epoch * FLAGS.warmup_epochs) or int(num_train_steps * FLAGS.warmup_proportion) logging.info('num_train_steps', num_train_steps, 'num_warmup_steps', num_warmup_steps, 'warmup_proportion', FLAGS.warmup_proportion) optimizer_ = torch.optim.Adamax(model.parameters(), lr=0) if use_horovod: optimizer_ = hvd.DistributedOptimizer(optimizer_) optimizer = lele.training.optimizers.BertOpt( FLAGS.learning_rate, FLAGS.min_learning_rate, num_train_steps, num_warmup_steps, optimizer_ ) else: is_dynamic_opt = False optimizer = torch.optim.Adamax(param_groups if param_groups else model.parameters(), lr=FLAGS.learning_rate) if use_horovod: optimizer 
= hvd.DistributedOptimizer(optimizer) optimizer_ = optimizer else: if use_horovod: optimizer = hvd.DistributedOptimizer(optimizer) optimizer_ = optimizer start_epoch = 0 latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(FLAGS.model_dir, 'latest.pyt') if not os.path.exists(latest_path): latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt') if os.path.exists(latest_path): logging.info('loading torch model from', latest_path) checkpoint = torch.load(latest_path) if not FLAGS.torch_finetune: start_epoch = checkpoint['epoch'] step = checkpoint['step'] global_step.assign(step + 1) load_torch_model(model, latest_path) if FLAGS.torch_load_optimizer: optimizer.load_state_dict(checkpoint['optimizer']) # TODO by this way restart can not change learning rate.. if learning_rate_weights is None: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, global_step=global_step) else: checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, learning_rate_weight=learning_rate_weight, learning_rate_weights=learning_rate_weights, global_step=global_step) try: checkpoint.restore(latest_checkpoint) checkpoint2 = copy.deepcopy(checkpoint) except Exception: pass if FLAGS.torch and is_dynamic_opt: optimizer._step = global_step.numpy() #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1')) #model.save('./weight3.hd5') logging.info('optimizer:', optimizer) if FLAGS.torch_lr: learning_rate.assign(optimizer.rate(1)) if FLAGS.torch: learning_rate.assign(optimizer.param_groups[0]['lr']) logging.info('learning rate got from pytorch latest.py as', learning_rate.numpy()) learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor) if learning_rate_weights is not None: learning_rate_weights.assign(learning_rate_weights * FLAGS.learning_rate_start_factor) # TODO currently not support 0.1 epoch.. 
like this num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024 will_valid = valid_dataset and not FLAGS.work_mode == 'test' and not 'SHOW' in os.environ and not 'QUICK' in os.environ if global_step.numpy() == 0 : will_valid = False if gezi.get_env('EVFIRST') == '1': will_valid = True if gezi.get_env('EVFIRST') == '0': will_valid = False if will_valid: logging.info('----------valid') if hasattr(model, 'eval'): model.eval() names = None if evaluate_fn is not None: vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch) elif eval_fn: model_path = None if not write_valid else latest_checkpoint names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None logging.info('model_path:', model_path, 'model_dir:', FLAGS.model_dir) vals, names = evaluate(model, valid_dataset, eval_fn, model_path, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, num_valid_examples, suffix=valid_suffix, sep=sep) if names: logging.info2('epoch:%.2f/%d step:%d' % (global_step.numpy() / num_steps_per_epoch, num_epochs, global_step.numpy()), ['%s:%.4f' % (name, val) for name, val in zip(names, vals)]) if FLAGS.work_mode == 'valid' or gezi.get_env('METRIC') == '1': exit(0) if 'test' in FLAGS.work_mode or gezi.get_env('TEST') == '1' or gezi.get_env('INFER') == '1': logging.info('--------test/inference') if test_dataset: if hasattr(model, eval): model.eval() if inference_fn is None: # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint # logging.info('model_path', model_path) assert latest_checkpoint inference(model, test_dataset, latest_checkpoint, infer_names, infer_debug_names, infer_write_fn, write_streaming, num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix) else: inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch) exit(0) if 'SHOW' 
in os.environ: num_epochs = start_epoch + 1 class PytObj(object): def __init__(self, x): self.x = x def numpy(self): return self.x class PytMean(object): def __init__(self): self._val = 0. self.count = 0 self.is_call = True def clear(self): self._val = 0 self.count = 0 def __call__(self, val): if not self.is_call: self.clear() self.is_call = True self._val += val.item() self.count += 1 def result(self): if self.is_call: self.is_call = False if not self.count: val = 0 else: val = self._val / self.count # TODO just for compact with tf .. return PytObj(val) Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean num_insts = 0 if FLAGS.learning_rate_decay_factor > 0: #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?' #NOTICE if you do finetune or other things which might change batch_size then you'd better direclty set num_steps_per_decay #since global step / decay_steps will not be correct epoch as num_steps per epoch changed #so if if you change batch set you have to reset global step as fixed step assert FLAGS.num_steps_per_decay or (FLAGS.num_epochs_per_decay and num_steps_per_epoch), 'must set num_steps_per_epoch or num_epochs_per_decay and num_steps_per_epoch' decay_steps = FLAGS.num_steps_per_decay or int(num_steps_per_epoch * FLAGS.num_epochs_per_decay) decay_start_step = FLAGS.decay_start_step or int(num_steps_per_epoch * FLAGS.decay_start_epoch) # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps) logging.info('learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'.format( FLAGS.learning_rate_decay_factor, FLAGS.num_epochs_per_decay, decay_steps, FLAGS.decay_start_epoch, decay_start_step)) #-------------------------start training if hasattr(model, 'train'): model.train() if use_horovod: hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer_, root_rank=0) timer = gezi.Timer() loss_avg = Mean() 
valid_loss_avg = Mean() num_epochs = num_epochs if num_epochs else 0 loops = min(num_epochs, 1) if FLAGS.torch_only else 1 for _ in range(loops): for i, (x, y) in enumerate(train_dataset): if FLAGS.torch: x, y = to_torch(x, y) if is_dynamic_opt: learning_rate.assign(optimizer.rate()) def loss_fn_(x, y): if not FLAGS.torch and 'training' in inspect.getargspec(model.call).args: y_ = model(x, training=True) else: y_ = model(x) if not FLAGS.torch: return loss_fn(y, y_) else: return loss_fn(y_, y) if not FLAGS.torch: loss, grads = melt.eager.grad(model, x, y, loss_fn) grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients) #optimizer.apply_gradients(zip(grads, model.variables)) optimizer.apply_gradients(zip(grads, model.trainable_variables)) # https://github.com/horovod/horovod/blob/master/examples/tensorflow_mnist_eager.py # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when # training is started with random weights or restored from a checkpoint. # Note: broadcast should be done after the first gradient step to ensure optimizer # initialization. 
# TODO check eager mode if use_horovod and epoch == start_epoch and i == 0: hvd.broadcast_variables(model.variables, root_rank=0) hvd.broadcast_variables(optimizier.variables(), root_rank=0) else: optimizer.zero_grad() loss = loss_fn_(x, y) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), FLAGS.clip_gradients) optimizer.step() global_step.assign_add(1) loss_avg(loss) ## https://discuss.pytorch.org/t/calling-loss-backward-reduce-memory-usage/2735 # if FLAGS.torch: # del loss batch_size_ = list(x.values())[0].shape[FLAGS.batch_size_dim] if type(x) == type({}) else x.shape[FLAGS.batch_size_dim] num_insts += int(batch_size_) if global_step.numpy() % FLAGS.interval_steps == 0: #checkpoint.save(checkpoint_prefix) elapsed = timer.elapsed() steps_per_second = FLAGS.interval_steps / elapsed instances_per_second = num_insts / elapsed num_insts = 0 if num_steps_per_epoch is None: epoch_time_info = '' else: hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600 epoch_time_info = '1epoch:[{:.2f}h]'.format(hours_per_epoch) if valid_dataset2: # try: # x, y = next(iter(valid_dataset2)) # except Exception: # # TODO FIXME how.. iterate stop restart.., here hack for my iterator see projects/lm/dataset # x, y = next(iter(valid_dataset2)) ## valid dataset2 is repeated ## NOTICE will always the first batch ... 
as below #x, y = next(iter(valid_dataset2)) x, y = next(valid_dataset2_iter) #print(x['id'][0]) if FLAGS.torch: x, y = to_torch(x, y) if hasattr(model, 'eval'): model.eval() valid_loss = loss_fn_(x, y) valid_loss = valid_loss.numpy() if not FLAGS.torch else valid_loss.item() if hasattr(model, 'train'): model.train() if not use_horovod or hvd.rank() == 0: # 'train_loss:[%.4f]' % loss_avg.result().numpy(), # 'valid_loss:[%.4f]' % valid_loss_avg.result().numpy() logging.info2('epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch), num_epochs), 'step:%d' % global_step.numpy(), 'elapsed:[%.2f]' % elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' % num_gpus, 'batches/s:[%.2f]' % steps_per_second, 'insts/s:[%d]' % instances_per_second, '%s' % epoch_time_info, 'lr:[%.6f]' % learning_rate.numpy(), 'train_loss:[%.4f]' % loss_avg.result().numpy(), 'valid_loss:[%.4f]' % valid_loss ) if global_step.numpy() % FLAGS.valid_interval_steps == 0: with writer_valid.as_default(), summary.always_record_summaries(): summary.scalar('loss/valid', valid_loss) writer_valid.flush() else: if not use_horovod or hvd.rank() == 0: #'train_loss:[%.4f]' % loss_avg.result().numpy() logging.info2('epoch:%.2f/%d' % ((epoch + i / num_steps_per_epoch), num_epochs), 'step:%d' % global_step.numpy(), 'elapsed:[%.2f]' % elapsed, 'batch_size:[%d]' % batch_size_, 'gpus:[%d]' % num_gpus, 'batches/s:[%.2f]' % steps_per_second, 'insts/s:[%d]' % instances_per_second, '%s' % epoch_time_info, 'lr:[%.6f]' % learning_rate.numpy(), 'train_loss:[%.4f]' % loss_avg.result().numpy() ) if not use_horovod or hvd.rank() == 0: if global_step.numpy() % FLAGS.valid_interval_steps == 0: with writer_train.as_default(), summary.always_record_summaries(): summary.scalar('loss/train_avg', loss_avg.result().numpy()) summary.scalar('learning_rate', learning_rate.numpy()) summary.scalar('other/batch_size', batch_size_) summary.scalar('other/epoch', melt.epoch()) summary.scalar('perf/steps_per_second', steps_per_second) 
summary.scalar('perf/instances_per_second', instances_per_second) writer_train.flush() if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy() and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0: if hasattr(model, eval): model.eval() vals, names = None, None if evaluate_fn is not None: vals, names = evaluate_fn(model, valid_dataset, None, num_valid_steps_per_epoch) elif eval_fn: names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + infer_names[1:] if infer_names else None vals, names = evaluate(model, valid_dataset, eval_fn, None, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, num_valid_examples, sep=sep) if not use_horovod or hvd.rank() == 0: if vals and names: with writer_valid.as_default(), summary.always_record_summaries(): for name, val in zip(names, vals): summary.scalar(f'step_eval/{name}', val) writer_valid.flush() if FLAGS.torch: if not FLAGS.torch_lr: # control learning rate by tensorflow learning rate for param_group in optimizer.param_groups: # important learning rate decay param_group['lr'] = learning_rate.numpy() if hasattr(model, 'train'): model.train() if not use_horovod or hvd.rank() == 0: if names and vals: logging.info2('epoch:%.2f/%d' % ((global_step.numpy() / num_steps_per_epoch), num_epochs), 'valid_step:%d' % global_step.numpy(), 'valid_metrics', ['%s:%.5f' % (name, val) for name, val in zip(names, vals)]) if not use_horovod or hvd.rank() == 0: # TODO save ok ? if global_step.numpy() % FLAGS.save_interval_steps == 0: if FLAGS.torch: state = { 'epoch': int(global_step.numpy() / num_steps_per_epoch), 'step': global_step.numpy(), 'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(), 'optimizer' : optimizer.state_dict(), } torch.save(state, os.path.join(FLAGS.model_dir, 'latest.pyt')) # TODO fixme why if both checpoint2 and chekpoint used... not ok.. 
if FLAGS.save_interval_epochs and global_step.numpy() % int(num_steps_per_epoch * FLAGS.save_interval_epochs) == 0: checkpoint2.save(checkpoint_prefix2) if FLAGS.torch: state = { 'epoch': int(global_step.numpy() / num_steps_per_epoch), 'step': global_step.numpy(), 'state_dict': model.state_dict() if not hasattr(model, 'module') else model.module.state_dict(), 'optimizer' : optimizer.state_dict(), } torch.save(state, tf.train.latest_checkpoint(ckpt_dir2) + '.pyt') if FLAGS.learning_rate_decay_factor > 0: if global_step.numpy() >= decay_start_step and global_step.numpy() % decay_steps == 0: lr = max(learning_rate.numpy() * FLAGS.learning_rate_decay_factor, FLAGS.min_learning_rate) if lr < learning_rate.numpy(): learning_rate.assign(lr) if FLAGS.torch: for param_group in optimizer.param_groups: param_group['lr'] = learning_rate.numpy() if i == 0: try: if not FLAGS.torch: logging.info(model.summary()) # #tf.keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='TB') # import keras # keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='LR', expand_nested=True, dpi=96) else: logging.info(model) except Exception: traceback.print_exc() logging.info('Fail to do model.summary() may be you have layer define in init but not used in call') if 'SHOW' in os.environ: exit(0) if valid_dataset and global_step.numpy() % int(num_steps_per_epoch * FLAGS.valid_interval_epochs) == 0: if hasattr(model, 'eval'): model.eval() vals, names = None, None if evaluate_fn is not None: vals, names = evaluate_fn(model, valid_dataset, tf.train.latest_checkpoint(ckpt_dir), num_valid_steps_per_epoch) elif eval_fn: model_path = None if not write_valid else tf.train.latest_checkpoint(ckpt_dir) print('---------metric evaluate step', global_step.numpy(), 'model_path:', model_path) names = valid_names if valid_names is not None else [infer_names[0]] + [x + '_y' for x in infer_names[1:]] + 
infer_names[1:] if infer_names else None vals, names = evaluate(model, valid_dataset, eval_fn, model_path, names, valid_write_fn, write_streaming, num_valid_steps_per_epoch, num_valid_examples, suffix=valid_suffix, sep=sep) if not use_horovod or hvd.rank() == 0: if vals and names: logging.info2('epoch:%.2f/%d' % (global_step.numpy() / num_steps_per_epoch, num_epochs), 'step:%d' % global_step.numpy(), 'valid_metrics', ['%s:%.5f' % (name, val) for name, val in zip(names, vals)]) if not use_horovod or hvd.rank() == 0: with writer.as_default(), summary.always_record_summaries(): temp = global_step.value() global_step.assign(int(global_step.numpy() / int(num_steps_per_epoch * FLAGS.valid_interval_epochs))) if valid_dataset: if hasattr(model, 'eval'): model.eval() if vals and names: for name, val in zip(names, vals): summary.scalar(f'eval/{name}', val) writer.flush() global_step.assign(temp) if test_dataset and global_step.numpy() % int(num_steps_per_epoch * FLAGS.inference_interval_epochs) == 0: if hasattr(model, 'eval'): model.eval() if inference_fn is None: inference(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), infer_names, infer_debug_names, infer_write_fn, write_streaming, num_test_steps_per_epoch, num_test_examples, suffix=infer_suffix, sep=sep) else: inference_fn(model, test_dataset, tf.train.latest_checkpoint(ckpt_dir), num_test_steps_per_epoch) if num_epochs and (global_step.numpy() % num_steps_per_epoch) == 0 and int(global_step.numpy() / num_steps_per_epoch) == num_epochs: logging.info(f'Finshed training of {num_epochs} epochs') exit(0)
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    eval_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, incase you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
):
  """Run one training iteration under a TF1 Session, with optional inline eval.

  Responsibilities per call:
    * run `ops` (ops[0] train op, ops[1] learning rate, ops[2:] losses/metrics
      — see the NOTICE comment below) and accumulate an average train loss;
    * every `eval_interval_steps` (or at start), run `eval_ops` and log them;
    * every `metric_eval_interval_steps` (or when `model_path` is given), call
      `metric_eval_fn` for full validation metrics, and optionally decay the
      learning rate when the first metric stops improving for
      `learning_rate_patience` evaluations;
    * write TensorBoard summaries to `log_dir`.

  Persistent state is stashed on the function object itself
  (train_once.timer, train_once.summary_op, train_once.summary_writer,
  train_once.avg_loss, train_once.min_valid_loss, train_once.epoch_step, ...)
  so it survives across calls without needing a class instance.

  Returns:
    The stop flag (bool) on the eval-step path that ends in `return stop`;
    other paths fall off the end and return None.  Callers treating the
    result as truthy-stop are unaffected.
  """
  #is_start = False # force not to evaluate at first step
  #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
  timer = gezi.Timer()
  if print_time:
    # Lazily create the persistent timers on the very first call.
    if not hasattr(train_once, 'timer'):
      train_once.timer = Timer()
      train_once.eval_timer = Timer()
      train_once.metric_eval_timer = Timer()

  melt.set_global('step', step)
  # fixed_step overrides step for epoch computation (e.g. after a batch-size
  # change the raw global step no longer maps to epochs).
  epoch = (fixed_step or step) / num_steps_per_epoch if num_steps_per_epoch else -1
  if not num_epochs:
    epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
  else:
    epoch_str = 'epoch:%.3f/%d' % (
        epoch, num_epochs) if num_steps_per_epoch else ''
  melt.set_global('epoch', '%.2f' % (epoch))

  info = IO()
  stop = False

  # Derive eval names from train names when not given, then namespace both
  # for TensorBoard ('train/...', 'eval/...').
  if eval_names is None:
    if names:
      eval_names = ['eval/' + x for x in names]

  if names:
    names = ['train/' + x for x in names]

  if eval_names:
    eval_names = ['eval/' + x for x in eval_names]

  is_eval_step = is_start or eval_interval_steps and step % eval_interval_steps == 0
  summary_str = []

  if is_eval_step:
    # deal with summary
    if log_dir:
      # One-time setup of the merged summary op and the FileWriter.
      if not hasattr(train_once, 'summary_op'):
        #melt.print_summary_ops()
        if summary_excls is None:
          train_once.summary_op = tf.summary.merge_all()
        else:
          # Keep only summaries whose name does not contain an excluded token.
          # NOTE(review): with multiple summary_excls an op is appended once
          # per non-matching pattern, so it can be added more than once —
          # confirm this duplication is intended.
          summary_ops = []
          for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
            for summary_excl in summary_excls:
              if not summary_excl in op.name:
                summary_ops.append(op)
          print('filtered summary_ops:')
          for op in summary_ops:
            print(op)
          train_once.summary_op = tf.summary.merge(summary_ops)

        #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
        train_once.summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        # NOTE(review): projector_config is a module-level global assumed to be
        # configured elsewhere in this file — confirm.
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            train_once.summary_writer, projector_config)

    if eval_ops is not None:
      #if deal_eval_results_fn is None and eval_names is not None:
      #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
      for i in range(eval_loops):
        eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn()
        #eval_feed_dict.update(feed_dict)

        # might use EVAL_NO_SUMMARY if some old code has problem TODO CHECK
        if not log_dir or train_once.summary_op is None or gezi.env_has('EVAL_NO_SUMMARY'):
        #if not log_dir or train_once.summary_op is None:
          eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
        else:
          # Piggy-back the summary fetch on the eval run, then split it off.
          eval_results = sess.run(eval_ops + [train_once.summary_op],
                                  feed_dict=eval_feed_dict)
          summary_str = eval_results[-1]
          eval_results = eval_results[:-1]
        eval_loss = gezi.get_singles(eval_results)
        #timer_.print()
        eval_stop = False
        # @TODO user print should also use logging as a must ?
        #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        logging.info2('{} eval_step:{} eval_metrics:{}'.format(
            epoch_str, step, melt.parse_results(eval_loss, eval_names_)))

        # if deal_eval_results_fn is not None:
        #   eval_stop = deal_eval_results_fn(eval_results)
        assert len(eval_loss) > 0
        # NOTE(review): eval_stop is always False here since the deal callback
        # above is commented out, so this never sets stop — confirm intended.
        if eval_stop is True:
          stop = True
        eval_names_ = melt.adjust_names(eval_loss, eval_names)
        melt.set_global('eval_loss', melt.parse_results(eval_loss, eval_names_))
  elif interval_steps != eval_interval_steps:
    #print()
    pass

  metric_evaluate = False

  # if metric_eval_fn is not None \
  #   and (is_start \
  #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
  #     or (metric_eval_interval_steps \
  #       and step % metric_eval_interval_steps == 0)):
  #   metric_evaluate = True

  if metric_eval_fn is not None \
    and ((is_start or metric_eval_interval_steps \
         and step % metric_eval_interval_steps == 0) or model_path):
    metric_evaluate = True

  #if (is_start or step == 0) and (not 'EVFIRST' in os.environ):
  # Environment overrides: skip the very first metric eval unless EVFIRST
  # demands it; QUICK or EVFIRST=0 disables metric eval entirely.
  if ((step == 0) and (not 'EVFIRST' in os.environ)) \
      or ('QUICK' in os.environ) \
      or ('EVFIRST' in os.environ and os.environ['EVFIRST'] == '0'):
    metric_evaluate = False

  if metric_evaluate:
    # TODO better
    # Call metric_eval_fn with model_path only when it accepts that kwarg.
    if not model_path or 'model_path' not in inspect.getargspec(metric_eval_fn).args:
      l = metric_eval_fn()
      if len(l) == 2:
        evaluate_results, evaluate_names = l
        evaluate_summaries = None
      else:
        evaluate_results, evaluate_names, evaluate_summaries = l
    else:
      try:
        l = metric_eval_fn(model_path=model_path)
        if len(l) == 2:
          evaluate_results, evaluate_names = l
          evaluate_summaries = None
        else:
          evaluate_results, evaluate_names, evaluate_summaries = l
      except Exception:
        # NOTE(review): if this except fires, evaluate_results is never bound
        # and the logging call below raises NameError — confirm/fix upstream.
        logging.info('Do nothing for metric eval fn with exception:\n',
                     traceback.format_exc())

    logging.info2('{} valid_step:{} {}:{}'.format(
        epoch_str, step,
        'valid_metrics' if model_path is None else 'epoch_valid_metrics',
        melt.parse_results(evaluate_results, evaluate_names)))

    # Learning-rate decay on patience: track the first validation metric as a
    # loss (lower is better) and multiply the lr variable (ops[1]) by
    # learning_rate_decay_factor after `learning_rate_patience` bad evals.
    if learning_rate is not None and (learning_rate_patience and learning_rate_patience > 0):
      assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
      valid_loss = evaluate_results[0]
      if not hasattr(train_once, 'min_valid_loss'):
        train_once.min_valid_loss = valid_loss
        # (sic: 'deacy' typo kept — it is an existing attribute name)
        train_once.deacy_steps = []
        train_once.patience = 0
      else:
        if valid_loss < train_once.min_valid_loss:
          train_once.min_valid_loss = valid_loss
          train_once.patience = 0
        else:
          train_once.patience += 1
          logging.info2('{} valid_step:{} patience:{}'.format(
              epoch_str, step, train_once.patience))

      if learning_rate_patience and train_once.patience >= learning_rate_patience:
        lr_op = ops[1]
        lr = sess.run(lr_op) * learning_rate_decay_factor
        train_once.deacy_steps.append(step)
        logging.info2(
            '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
            .format(epoch_str, step, learning_rate_decay_factor,
                    ','.join(map(str, train_once.deacy_steps))))
        sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
        train_once.patience = 0
        train_once.min_valid_loss = valid_loss

  if ops is not None:
    #if deal_results_fn is None and names is not None:
    #  deal_results_fn = lambda x: melt.print_results(x, names)
    feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
    # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
    #print('---------------ops', ops)
    # Fetch the train-step summary together with ops only when eval did not
    # already fetch one this step.
    if eval_ops is not None or not log_dir or not hasattr(
        train_once, 'summary_op') or train_once.summary_op is None:
      results = sess.run(ops, feed_dict=feed_dict)
    else:
      #try:
      results = sess.run(ops + [train_once.summary_op], feed_dict=feed_dict)
      summary_str = results[-1]
      results = results[:-1]
      # except Exception:
      #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
      #   results = sess.run(ops, feed_dict=feed_dict)

    #print('------------results', results)

    # #--------trace debug
    # if step == 210:
    #   run_metadata = tf.RunMetadata()
    #   results = sess.run(
    #         ops,
    #         feed_dict=feed_dict,
    #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
    #         run_metadata=run_metadata)
    #   from tensorflow.python.client import timeline
    #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    #   trace_file = open('timeline.ctf.json', 'w')
    #   trace_file.write(trace.generate_chrome_trace_format())

    #reults[0] assume to be train_op, results[1] to be learning_rate
    # NOTE: this rebinds the `learning_rate` parameter to the fetched value.
    learning_rate = results[1]
    results = results[2:]

    #@TODO should support aver loss and other avg evaluations like test..
    if print_avg_loss:
      if not hasattr(train_once, 'avg_loss'):
        train_once.avg_loss = AvgScore()
        if interval_steps != eval_interval_steps:
          train_once.avg_loss2 = AvgScore()
      #assume results[0] as train_op return, results[1] as loss
      loss = gezi.get_singles(results)
      train_once.avg_loss.add(loss)
      if interval_steps != eval_interval_steps:
        train_once.avg_loss2.add(loss)

    steps_per_second = None
    instances_per_second = None
    hours_per_epoch = None
    #step += 1
    # Periodic console logging of throughput and averaged train loss.
    if is_start or interval_steps and step % interval_steps == 0:
      train_average_loss = train_once.avg_loss.avg_score()
      if print_time:
        duration = timer.elapsed()
        duration_str = 'duration:{:.3f} '.format(duration)
        melt.set_global('duration', '%.3f' % duration)
        info.write(duration_str)
        elapsed = train_once.timer.elapsed()
        steps_per_second = interval_steps / elapsed
        batch_size = melt.batch_size()
        num_gpus = melt.num_gpus()
        instances_per_second = interval_steps * batch_size / elapsed
        gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(num_gpus)
        if num_steps_per_epoch is None:
          epoch_time_info = ''
        else:
          hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
          epoch_time_info = ' 1epoch:[{:.2f}h]'.format(hours_per_epoch)
        info.write(
            'elapsed:[{:.3f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.8f}]'
            .format(elapsed, batch_size, gpu_info, steps_per_second,
                    instances_per_second, epoch_time_info, learning_rate))

      if print_avg_loss:
        #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
        names_ = melt.adjust_names(train_average_loss, names)
        #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
        info.write(' train:{} '.format(
            melt.parse_results(train_average_loss, names_)))
        #info.write('train_avg_loss: {} '.format(train_average_loss))

      #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
      logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step, info.getvalue()))

      if deal_results_fn is not None:
        stop = deal_results_fn(results)

  # ---- Summary writing phase ----
  summary_strs = gezi.to_list(summary_str)
  if metric_evaluate:
    if evaluate_summaries is not None:
      summary_strs += evaluate_summaries

  if step > 1:
    if is_eval_step:
      # deal with summary
      if log_dir:
        # (one-time summary_op/FileWriter setup used to live here too;
        # kept above, commented remnants removed for brevity)
        summary = tf.Summary()
        # #so the strategy is on eval_interval_steps, if has eval dataset, then tensorboard evluate on eval dataset
        # #if not have eval dataset, will evaluate on trainset, but if has eval dataset we will also monitor train loss
        # assert train_once.summary_train_op is None
        # if train_once.summary_train_op is not None:
        #   summary_str = sess.run(train_once.summary_train_op, feed_dict=feed_dict)
        #   train_once.summary_writer.add_summary(summary_str, step)
        if eval_ops is None:
          # #get train loss, for every batch train
          # if train_once.summary_op is not None:
          #   #timer2 = gezi.Timer('sess run')
          #   try:
          #     # TODO FIXME so this means one more train batch step without adding to global step counter ?! so should move it earlier
          #     summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict)
          #   except Exception:
          #     if not hasattr(train_once, 'num_summary_errors'):
          #       logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=feed_dict) fail')
          #       train_once.num_summary_errors = 1
          #       logging.warning(traceback.format_exc())
          #     summary_str = ''
          #   #timer2.print()
          if train_once.summary_op is not None:
            for summary_str in summary_strs:
              train_once.summary_writer.add_summary(summary_str, step)
        else:
          # #get eval loss for every batch eval, then add train loss for eval step average loss
          # try:
          #   summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) if train_once.summary_op is not None else ''
          # except Exception:
          #   if not hasattr(train_once, 'num_summary_errors'):
          #     logging.warning('summary_str = sess.run(train_once.summary_op, feed_dict=eval_feed_dict) fail')
          #     train_once.num_summary_errors = 1
          #     logging.warning(traceback.format_exc())
          #   summary_str = ''
          #all single value results will be add to summary here not using tf.scalar_summary..
          #summary.ParseFromString(summary_str)
          for summary_str in summary_strs:
            train_once.summary_writer.add_summary(summary_str, step)
          suffix = 'eval' if not eval_names else ''
          melt.add_summarys(summary, eval_results, eval_names_, suffix=suffix)

        if ops is not None:
          # NOTE(review): train_average_loss/names_ are only bound when the
          # interval-logging branch above ran; assumes eval steps coincide
          # with interval steps — confirm.
          melt.add_summarys(summary, train_average_loss, names_, suffix='train_avg')
          ##optimizer has done this also
          melt.add_summary(summary, learning_rate, 'learning_rate')
          melt.add_summary(summary, melt.batch_size(), 'batch_size')
          melt.add_summary(summary, melt.epoch(), 'epoch')
          if steps_per_second:
            melt.add_summary(summary, steps_per_second, 'steps_per_second')
          if instances_per_second:
            melt.add_summary(summary, instances_per_second, 'instances_per_second')
          if hours_per_epoch:
            melt.add_summary(summary, hours_per_epoch, 'hours_per_epoch')

        if metric_evaluate:
          #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
          prefix = 'step/valid'
          if model_path:
            # Epoch-level valid summaries are indexed by an epoch counter, not
            # by the raw global step.
            prefix = 'epoch/valid'
            if not hasattr(train_once, 'epoch_step'):
              train_once.epoch_step = 1
            else:
              train_once.epoch_step += 1
            step = train_once.epoch_step
          melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)

        train_once.summary_writer.add_summary(summary, step)
        train_once.summary_writer.flush()
      #timer_.print()
      # if print_time:
      #   full_duration = train_once.eval_timer.elapsed()
      #   if metric_evaluate:
      #     metric_full_duration = train_once.metric_eval_timer.elapsed()
      #   full_duration_str = 'elapsed:{:.3f} '.format(full_duration)
      #   #info.write('duration:{:.3f} '.format(timer.elapsed()))
      #   duration = timer.elapsed()
      #   info.write('duration:{:.3f} '.format(duration))
      #   info.write(full_duration_str)
      #   info.write('eval_time_ratio:{:.3f} '.format(duration/full_duration))
      #   if metric_evaluate:
      #     info.write('metric_time_ratio:{:.3f} '.format(duration/metric_full_duration))
      #   #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, info.getvalue())
      #   logging.info2('{} {} {}'.format(epoch_str, 'eval_step: %d'%step, info.getvalue()))
      return stop
    elif metric_evaluate:
      # Metric eval happened on a non-eval step: write only the valid metrics.
      summary = tf.Summary()
      for summary_str in summary_strs:
        train_once.summary_writer.add_summary(summary_str, step)
      #summary.ParseFromString(evaluate_summaries)
      summary_writer = train_once.summary_writer
      prefix = 'step/valid'
      if model_path:
        prefix = 'epoch/valid'
        if not hasattr(train_once, 'epoch_step'):
          ## TODO.. restart will get 1 again..
          #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
          #epoch_step += 1
          #train_once.epoch_step = sess.run(epoch_step)
          valid_interval_epochs = 1.
          try:
            valid_interval_epochs = FLAGS.valid_interval_epochs
          except Exception:
            pass
          # Recover the epoch-summary counter after a restart from the current
          # epoch and the valid interval (x10 to avoid float division issues).
          train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
              int(melt.epoch() * 10) / int(valid_interval_epochs * 10))
          logging.info('train_once epoch start step is', train_once.epoch_step)
        else:
          #epoch_step += 1
          train_once.epoch_step += 1
        step = train_once.epoch_step
      #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
      melt.add_summarys(summary, evaluate_results, evaluate_names, prefix=prefix)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()
def _write_metric_details(metric_name, per_example_scores, selected_results, selected_refs):
  """Dump one tab-separated line per example (key, prediction, refs, score)
  to <FLAGS.eval_result_dir>/<metric_name>.txt."""
  # Snapshot the key order once; dict order is insertion-stable (py3.7+) so it
  # matches the order the scorer iterated.
  keys = list(selected_results)
  # Use a context manager — the original leaked the file handle.
  with open(os.path.join(FLAGS.eval_result_dir, metric_name + '.txt'), 'w') as out:
    for i, sc in enumerate(per_example_scores):
      key = keys[i]
      result = selected_results[key]
      refs_str = '\x01'.join(selected_refs[key])
      print(key, result, refs_str, sc, sep='\t', file=out)


def evaluate_translation(predictor, random=False, index=None):
  """Evaluate caption/translation quality with BLEU-4, METEOR, ROUGE-L, CIDEr.

  Args:
    predictor: project predictor consumed by translation_predicts.
    random: sample FLAGS.num_metric_eval_examples images at random instead of
      taking the head of the dataset (forced off when evaluating everything).
    index: optional precomputed sample index, used when random=True.

  Returns:
    (score_list, metric_list): bleu_4/meteor/rouge_l/cider plus their average,
    with metric names prefixed 'trans_'.
  """
  timer = gezi.Timer('evaluate_translation')
  refs = prepare_refs()
  imgs, img_features = get_image_names_and_features()
  num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
  if num_metric_eval_examples <= 0:
    # <= 0 means: evaluate on all images
    num_metric_eval_examples = len(imgs)
  if num_metric_eval_examples == len(imgs):
    random = False
  step = FLAGS.metric_eval_batch_size

  if random:
    if index is None:
      index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
    imgs = imgs[index]
    img_features = img_features[index]
  else:
    img_features = img_features[:num_metric_eval_examples]

  # Predict in batches; results maps image key -> [prediction].
  results = {}
  start = 0
  while start < num_metric_eval_examples:
    end = min(start + step, num_metric_eval_examples)
    print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
    translation_predicts(imgs[start:end], img_features[start:end], predictor, results)
    start = end

  scorers = [
      (Bleu(4), ["bleu_1", "bleu_2", "bleu_3", "bleu_4"]),
      (Meteor(), "meteor"),
      (Rouge(), "rouge_l"),
      (Cider(), "cider"),
  ]
  score_list = []
  metric_list = []

  # Restrict refs to predicted keys so both dicts share exactly the same keys.
  selected_refs = {}
  selected_results = {}
  for key in results:
    selected_refs[key] = refs[key]
    selected_results[key] = results[key]
    assert len(selected_results[key]) == 1, selected_results[key]
  assert selected_results.keys() == selected_refs.keys(), '%d %d' % (
      len(selected_results.keys()), len(selected_refs.keys()))

  if FLAGS.eval_translation_reseg:
    print('tokenization...', file=sys.stderr)
    global tokenizer
    if tokenizer is None:
      tokenizer = PTBTokenizer()
    selected_refs = tokenizer.tokenize(selected_refs)
    selected_results = tokenizer.tokenize(selected_results)

  # FIX: dict.items()[0] is Python-2 only; use next(iter(...)) on Python 3.
  logging.info('predict&label:{}{}{}'.format(
      '|'.join(next(iter(selected_results.values()))),
      '---',
      '|'.join(next(iter(selected_refs.values())))))

  for scorer, method in scorers:
    print('computing %s score...' % (scorer.method()), file=sys.stderr)
    score, scores = scorer.compute_score(selected_refs, selected_results)
    if isinstance(method, list):
      # Bleu returns parallel lists: aggregate score + per-example scores
      # for each n-gram order.
      for sc, scs, m in zip(score, scores, method):
        score_list.append(sc)
        metric_list.append(m)
        if FLAGS.eval_result_dir:
          _write_metric_details(m, scs, selected_results, selected_refs)
    else:
      score_list.append(score)
      metric_list.append(method)
      if FLAGS.eval_result_dir:
        # FIX: original referenced undefined 'm' in this branch; the metric
        # name here is `method`.
        _write_metric_details(method, scores, selected_results, selected_refs)

  #exclude "bleu_1", "bleu_2", "bleu_3"
  score_list, metric_list = score_list[3:], metric_list[3:]
  assert len(score_list) == 4
  avg_score = sum(score_list) / len(score_list)
  score_list.append(avg_score)
  metric_list.append('avg')
  metric_list = ['trans_' + x for x in metric_list]
  melt.logging_results(
      score_list,
      metric_list,
      tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
          melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))
  timer.print()
  return score_list, metric_list
def evaluate_scores(predictor, random=False, index=None, exact_predictor=None, exact_ratio=1.):
  """Rank-metric (recall@k) evaluation, by default recall@1,2,5,10,50.

  Runs img->text and/or text->img retrieval evaluation depending on
  FLAGS.eval_img2text / FLAGS.eval_text2img.

  Args:
    predictor: project predictor consumed by predicts / predicts_txt2im.
    random: sample FLAGS.num_metric_eval_examples items at random.
    index: optional precomputed sample index for the img->text direction.
    exact_predictor: optional second predictor for exact re-ranking.
    exact_ratio: blend ratio forwarded to predicts().

  Returns:
    (metrics, names) for the enabled direction(s); text->img metric names are
    prefixed with 't2i'.
  """
  timer = gezi.Timer('evaluate_scores')
  init()
  # FIX: hoisted out of the img2text branch — the text2img loop below also
  # uses `step` and raised NameError when only FLAGS.eval_text2img was on.
  step = FLAGS.metric_eval_batch_size

  if FLAGS.eval_img2text:
    imgs, img_features = get_image_names_and_features()
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(imgs))
    if num_metric_eval_examples <= 0:
      # <= 0 means: evaluate on all images
      num_metric_eval_examples = len(imgs)
    if num_metric_eval_examples == len(imgs):
      random = False

    if random:
      if index is None:
        index = np.random.choice(len(imgs), num_metric_eval_examples, replace=False)
      imgs = imgs[index]
      img_features = img_features[index]
    else:
      img_features = img_features[:num_metric_eval_examples]

    rank_metrics = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = min(start + step, num_metric_eval_examples)
      print('predicts image start:', start, 'end:', end, file=sys.stderr, end='\r')
      predicts(imgs[start:end], img_features[start:end], predictor, rank_metrics,
               exact_predictor=exact_predictor, exact_ratio=exact_ratio)
      start = end

    melt.logging_results(
        rank_metrics.get_metrics(),
        rank_metrics.get_names(),
        tag='evaluate: epoch:{} step:{} train:{} eval:{}'.format(
            melt.epoch(), melt.step(), melt.train_loss(), melt.eval_loss()))

  if FLAGS.eval_text2img:
    num_metric_eval_examples = min(FLAGS.num_metric_eval_examples, len(all_distinct_texts))
    if num_metric_eval_examples <= 0:
      # FIX: mirror the img2text branch — <= 0 means evaluate everything
      # (previously a non-positive flag value silently skipped the loop).
      num_metric_eval_examples = len(all_distinct_texts)
    if random:
      index = np.random.choice(len(all_distinct_texts), num_metric_eval_examples, replace=False)
      text_strs = all_distinct_text_strs[index]
      texts = all_distinct_texts[index]
    else:
      text_strs = all_distinct_text_strs
      texts = all_distinct_texts

    rank_metrics2 = gezi.rank_metrics.RecallMetrics()

    start = 0
    while start < num_metric_eval_examples:
      end = min(start + step, num_metric_eval_examples)
      print('predicts start:', start, 'end:', end, file=sys.stderr, end='\r')
      predicts_txt2im(text_strs[start:end], texts[start:end], predictor, rank_metrics2,
                      exact_predictor=exact_predictor)
      start = end

    melt.logging_results(
        rank_metrics2.get_metrics(),
        ['t2i' + x for x in rank_metrics2.get_names()],
        tag='text2img')

  timer.print()

  if FLAGS.eval_img2text and FLAGS.eval_text2img:
    return (rank_metrics.get_metrics() + rank_metrics2.get_metrics(),
            rank_metrics.get_names() + ['t2i' + x for x in rank_metrics2.get_names()])
  elif FLAGS.eval_img2text:
    return rank_metrics.get_metrics(), rank_metrics.get_names()
  else:
    return rank_metrics2.get_metrics(), rank_metrics2.get_names()