def split_epoch_end(self, outputs, split='val'):
    """Aggregate per-batch outputs at the end of a val/test epoch.

    Gathers `outputs` from all distributed workers, and on the master
    process only: averages the loss, concatenates predictions, saves them
    to `eval_results/`, optionally runs language evaluation, and
    optionally steps a reduce-on-plateau scheduler. The master's result
    dict is then broadcast back to every rank.

    Args:
        outputs: list of per-batch dicts with keys 'loss' (tensor),
            'predictions' (list) and 'n_predictions' (list).
        split: dataset split name used in saved filenames ('val'/'test').

    Returns:
        dict mapping metric names to tensors; always contains 'loss' and
        'to_monitor' (CIDEr if language eval ran, else -loss).
    """
    # Collect the per-rank output lists; only the master does the heavy work.
    outputs = d2comm.gather(outputs)
    if d2comm.is_main_process():
        # Sanity check: the main process must be rank 0 on node 0.
        assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
        # Flatten list-of-lists (one list per rank) into one list of batches.
        outputs = sum(outputs, [])
        opt = self.opt

        loss_mean = sum(batch['loss'].item() for batch in outputs) / len(outputs)
        predictions = sum([batch['predictions'] for batch in outputs], [])
        if len(outputs[0]['n_predictions']) != 0:
            n_predictions = sum([batch['n_predictions'] for batch in outputs], [])
        else:
            n_predictions = []

        lang_stats = None
        # If sampled candidates carry perplexity, rank them by it (ascending).
        if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
            n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])

        # Persist raw predictions so evaluation can be re-run offline.
        # makedirs(exist_ok=True) avoids the isdir/mkdir race of the old code.
        os.makedirs('eval_results', exist_ok=True)
        torch.save(
            (predictions, n_predictions),
            os.path.join('eval_results/',
                         '.saved_pred_' + opt.id + '_' + split + '.pth'))

        if opt.language_eval:
            lang_stats = eval_utils.language_eval(
                opt.input_json, predictions, n_predictions, vars(opt), split)

        if opt.reduce_on_plateau:
            optimizer = self.trainer.optimizers[0]
            # BUGFIX: lang_stats is None when language_eval is off; the old
            # `'CIDEr' in lang_stats` raised TypeError in that case.
            if lang_stats is not None and 'CIDEr' in lang_stats:
                # Scheduler minimizes, so feed negated CIDEr.
                optimizer.scheduler_step(-lang_stats['CIDEr'])
            else:
                optimizer.scheduler_step(loss_mean)

        out = {'loss': loss_mean}
        # BUGFIX: guard the update — dict.update(None) raises TypeError.
        if lang_stats is not None:
            out.update(lang_stats)
        out['to_monitor'] = lang_stats[
            'CIDEr'] if lang_stats is not None else -loss_mean
    else:
        out = {}

    # Broadcast the master's dict to every rank; index 0 is the master's.
    out = d2comm.all_gather(out)[0]  # Only the one from master node
    assert len(out) > 0  # make sure the head has index 0
    # Downstream logging expects tensors, so wrap any plain numbers.
    out = {
        k: torch.tensor(v) if not torch.is_tensor(v) else v
        for k, v in out.items()
    }
    return out
result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json') if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): # if results existed, then skip, unless force is on if not opt.force: try: if os.path.isfile(result_fn): print(result_fn) json.load(open(result_fn, 'r')) print('already evaluated') os._exit(0) except: pass predictions, n_predictions = torch.load(pred_fn) lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split) print(lang_stats) os._exit(0) # At this point only_lang_eval if 0 if not opt.force: # Check out if try: # if no pred exists, then continue tmp = torch.load(pred_fn) # if language_eval == 1, and no pred exists, then continue if opt.language_eval == 1: json.load(open(result_fn, 'r')) print('Result is already there') os._exit(0) except: