def test_folder():
    # Smoke test: build the dataset from a saved infos file and materialize one training sample
    # with features read from the feature directories.
    x = pickle_load(open('log_trans/infos_trans.pkl', 'rb'))
    dataset = CaptionDataset(x['opt'])
    ds = torch.utils.data.Subset(dataset, dataset.split_ix['train'])
    ds[0]
class LitModel(pl.LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Initialize dataset
        self.dataset = CaptionDataset(opt)
        opt.vocab_size = self.dataset.vocab_size
        opt.seq_length = self.dataset.seq_length
        self.batch_size = opt.batch_size
        # Build model
        opt.vocab = self.dataset.get_vocab()
        model = models.setup(opt)
        print(model)
        del opt.vocab
        # Wrapper with the loss in it.
        lw_model = LossWrapper(model, opt)
        self.model = model
        self.lw_model = lw_model
        # Training-mode flags; they are expected to be set from outside this module
        # (e.g. by a callback) before training_step runs.
        self.struc_flag = None
        self.sc_flag = None

    def forward(self, *args, **kwargs):
        """
        I hate this design. Never pretend it is an nn.Module.
        """
        raise NotImplementedError

    def train_dataloader(self):
        train_dataset = torch.utils.data.Subset(
            self.dataset, self.dataset.split_ix['train'])
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self.dataset.collate_func)
        return train_loader

    def val_dataloader(self, split='val'):
        val_dataset = torch.utils.data.Subset(
            self.dataset, self.dataset.split_ix[split])
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=self.dataset.collate_func)
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader('test')

    def training_step(self, data, batch_idx):
        sc_flag, struc_flag, drop_worst_flag = \
            self.sc_flag, self.struc_flag, self.drop_worst_flag

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'],
            data['masks'], data['att_masks']
        ]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if int(os.getenv('M2_cider', '0')) != 0:
            data['gts'] = data['rawgts']
        model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                  data['gts'], torch.arange(0, len(data['gts'])),
                                  sc_flag, struc_flag, drop_worst_flag)
        if not drop_worst_flag:
            loss = model_out.pop('loss').mean()
        else:
            # Keep only the (1 - drop_worst_rate) fraction of samples with the lowest loss.
            loss = model_out.pop('loss')
            loss = torch.topk(
                loss,
                k=int(loss.shape[0] * (1 - self.opt.drop_worst_rate)),
                largest=False)[0].mean()

        # Prepare logging info
        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        logger_logs = model_out.copy()
        if struc_flag or sc_flag:
            logger_logs['reward'] = model_out['reward'].mean()
            logger_logs['reward_var'] = model_out['reward'].var(1).mean()

        logger_logs['scheduled_sampling_prob'] = torch.tensor(self.model.ss_prob)
        logger_logs['training_loss'] = loss
        logger_logs['data_time'] = data_time

        for k, v in logger_logs.items():
            self.log(k, v,
                     on_epoch=(k == 'training_loss'),
                     prog_bar=(k == 'data_time'))  # logged

        return loss

    def validation_step(self, data, batch_idx):
        model = self.model
        crit = self.lw_model.crit

        opt = self.opt
        eval_kwargs = {'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))

        verbose = eval_kwargs.get('verbose', True)
        verbose_beam = eval_kwargs.get('verbose_beam', 0)
        verbose_loss = eval_kwargs.get('verbose_loss', 1)
        # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
        # lang_eval = eval_kwargs.get('language_eval', 0)
        dataset = eval_kwargs.get('dataset', 'coco')
        beam_size = eval_kwargs.get('beam_size', 1)
        sample_n = eval_kwargs.get('sample_n', 1)
        remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
        # Use this nasty way to keep other code clean, since it's a global configuration
        os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)

        predictions = []
        n_predictions = []

        loss = torch.tensor(0)
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'],
            data['masks'], data['att_masks']
        ]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            loss = crit(
                model(fc_feats, att_feats, labels[..., :-1], att_masks),
                labels[..., 1:], masks[..., 1:])

        # forward the model to also get generated samples for each image
        # Only leave one feature per image, in case of duplicate samples
        tmp_eval_kwargs = eval_kwargs.copy()
        tmp_eval_kwargs.update({'sample_n': 1})
        seq, seq_logprobs = model(
            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        seq = seq.data
        entropy = -(F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / \
            ((seq > 0).to(seq_logprobs).sum(1) + 1)
        perplexity = -seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / \
            ((seq > 0).to(seq_logprobs).sum(1) + 1)

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([
                    utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0]
                    for _ in model.done_beams[i]
                ]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        for k, sent in enumerate(sents):
            entry = {
                'image_id': data['infos'][k]['id'],
                'caption': sent,
                'perplexity': perplexity[k].item(),
                'entropy': entropy[k].item()
            }
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to the vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'],
                                            data['infos'][k]['file_path']) + \
                    '" vis/imgs/img' + str(len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_utils.eval_split_n(model, n_predictions,
                                    [fc_feats, att_feats, att_masks, data],
                                    eval_kwargs)

        output = {
            'val_loss': loss,
            'predictions': predictions,
            'n_predictions': n_predictions,
        }
        self.log('val_loss', loss)
        return output

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            opt = self.opt
            val_loss_mean = sum([_['val_loss'].item()
                                 for _ in outputs]) / len(outputs)

            predictions = sum([_['predictions'] for _ in outputs], [])
            if len(outputs[0]['n_predictions']) != 0:
                n_predictions = sum([_['n_predictions'] for _ in outputs], [])
            else:
                n_predictions = []

            lang_stats = None
            if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
                n_predictions = sorted(n_predictions,
                                       key=lambda x: x['perplexity'])

            if not os.path.isdir('eval_results'):
                os.mkdir('eval_results')
            torch.save(
                (predictions, n_predictions),
                os.path.join('eval_results/',
                             '.saved_pred_' + opt.id + '_' + split + '.pth'))

            if opt.language_eval:
                lang_stats = eval_utils.language_eval(opt.input_json, predictions,
                                                      n_predictions, vars(opt), split)

            if opt.reduce_on_plateau:
                optimizer = self.trainer.optimizers[0]
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss_mean)

            out = {'val_loss': val_loss_mean}
            if lang_stats is not None:
                out.update(lang_stats)
            out['to_monitor'] = lang_stats['CIDEr'] \
                if lang_stats is not None else -val_loss_mean
        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only keep the one from the master node
        assert len(out) > 0  # make sure the head has index 0
        # must all be tensors
        out = {
            k: torch.tensor(v) if not torch.is_tensor(v) else v
            for k, v in out.items()
        }
        for k, v in out.items():
            self.log(k, v)
        return out

    def test_epoch_end(self, outputs):
        out = self.validation_epoch_end(outputs, 'test')
        out['test_loss'] = out['val_loss']
        del out['val_loss']
        del out['to_monitor']
        out = {
            'test_' + k if 'test' not in k else k: v
            for k, v in out.items()
        }
        return out

    def configure_optimizers(self):
        opt = self.opt
        model = self.model
        if opt.noamopt:
            # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
            optimizer = utils.get_std_opt(model,
                                          optim_func=opt.optim,
                                          factor=opt.noamopt_factor,
                                          warmup=opt.noamopt_warmup)
        elif opt.reduce_on_plateau:
            optimizer = utils.build_optimizer(model.parameters(), opt)
            optimizer = utils.ReduceLROnPlateau(
                optimizer,
                factor=opt.reduce_on_plateau_factor,
                patience=opt.reduce_on_plateau_patience)
        else:
            optimizer = utils.build_optimizer(model.parameters(), opt)
        return [optimizer], []

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       *args, **kwargs):
        # warm up lr
        opt = self.opt
        iteration = self.trainer.global_step
        if opt.use_warmup and (iteration < opt.noamopt_warmup):
            opt.current_lr = opt.learning_rate * \
                (iteration + 1) / opt.noamopt_warmup
            utils.set_lr(optimizer, opt.current_lr)
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
                               *args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """
        Save the model state dict as well as opt and vocab
        """
        state_dict = self.model.state_dict(*args, **kwargs)
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        state_dict.update({
            '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
            '_opt': utils.serialize_to_tensor(self.opt).to(device)
        })
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        if '_vocab' in state_dict:
            self.model.vocab = utils.deserialize(state_dict['_vocab'])
            del state_dict['_vocab']
        elif strict:
            raise KeyError
        if '_opt' in state_dict:
            saved_model_opt = utils.deserialize(state_dict['_opt'])
            del state_dict['_opt']
            opt = self.opt
            # Make sure the saved opt is compatible with the current opt
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                        getattr(opt, checkme) in ['updown', 'topdown']:
                    continue
                assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), \
                    "Command line argument and saved model disagree on '%s'" % checkme
        elif strict:
            raise KeyError
        self.model.load_state_dict(state_dict, strict)

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items
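
# Illustrative sketch, not part of the original file: training_step reads self.sc_flag,
# self.struc_flag and self.drop_worst_flag, but __init__ only initializes the first two,
# so all three are presumably toggled from outside the module, e.g. by a callback like
# the one below. The option names (self_critical_after, structure_after, drop_worst_after)
# and the epoch-threshold schedule are assumptions for demonstration only; the hook name
# may also differ slightly across pytorch_lightning versions.
class FlagSettingCallback(pl.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        opt = pl_module.opt
        epoch = trainer.current_epoch
        # Enable self-critical / structure-loss / drop-worst training once the
        # (assumed) warm-up epoch thresholds are reached; otherwise plain XE.
        pl_module.sc_flag = 0 <= getattr(opt, 'self_critical_after', -1) <= epoch
        pl_module.struc_flag = 0 <= getattr(opt, 'structure_after', -1) <= epoch
        pl_module.drop_worst_flag = 0 <= getattr(opt, 'drop_worst_after', -1) <= epoch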
def test_lmdb():
    # Smoke test: point the attention-feature directory at an LMDB file and
    # materialize one training sample.
    x = pickle_load(open('log_trans/infos_trans.pkl', 'rb'))
    x['opt'].input_att_dir = 'data/vilbert_att.lmdb'
    dataset = CaptionDataset(x['opt'])
    ds = torch.utils.data.Subset(dataset, dataset.split_ix['train'])
    ds[0]
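
# Minimal usage sketch (an assumption, not taken from the original file): one way LitModel
# could be driven by a pl.Trainer. `opt` is whatever namespace this repo normally builds
# from its command-line options; `max_epochs` and the checkpoint settings are hypothetical.
# 'to_monitor' is grounded in the code above: validation_epoch_end logs it as CIDEr when
# language evaluation is on, and as -val_loss otherwise, so larger is better in both cases.
def example_train(opt):
    lit = LitModel(opt)
    checkpoint = pl.callbacks.ModelCheckpoint(
        monitor='to_monitor',  # logged by validation_epoch_end
        mode='max',
        save_top_k=1)
    trainer = pl.Trainer(
        max_epochs=getattr(opt, 'max_epochs', 30),  # assumed option
        callbacks=[checkpoint])
    trainer.fit(lit)     # uses train_dataloader / val_dataloader defined on LitModel
    trainer.test(lit)    # uses test_dataloader, i.e. the 'test' split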