def _calc_mem_scores(model, ckpt_path, dl, ds, batch_size, device,
                     preds_savepath, n_times):
    """Run inference ``n_times`` over the dataset, average the memorability
    predictions across rounds, and dump predictions plus the Spearman rank
    correlation against ground truth to ``preds_savepath`` as JSON.

    Args:
        model: trained model; called as ``model(x, y.get_data())``.
        ckpt_path: checkpoint path, recorded in the output JSON.
        dl: DataLoader yielding ``(x, y_)`` batches.
        ds: dataset backing ``dl`` (used only for the progress-bar total).
        batch_size: batch size of ``dl``.
        device: torch device to run on.
        preds_savepath: where the result JSON is written.
        n_times: number of full inference passes to average over.
    """
    alphas = []
    mems = []
    gt_mems = None
    gt_alphas = None

    # Ceil division so tqdm's total is an exact integer batch count
    # (the old float total made tqdm's display inaccurate).
    n_batches = (len(ds) + batch_size - 1) // batch_size

    for round_i in range(n_times):
        # BUG FIX: the inner loop previously reused `i`, clobbering the round
        # counter; renamed so the round index stays intact.
        print("Generating mem scores, round {}".format(round_i))

        preds: Optional[ModelOutput] = None
        labels: Optional[ModelOutput] = None

        with torch.no_grad():
            for _, (x, y_) in tqdm(enumerate(dl), total=n_batches):
                y: ModelOutput[MemModelFields] = ModelOutput(y_)
                y_list = y.to_numpy()
                labels = y_list if labels is None else labels.merge(y_list)

                x = x.to(device)
                y = y.to_device(device)
                out = ModelOutput(model(x, y.get_data()))
                out_list = out.to_device('cpu').to_numpy()
                preds = out_list if preds is None else preds.merge(out_list)

        mems.append(preds['score'])
        alphas.append(preds['alpha'])
        print("correlation", spearmanr(preds['score'], labels['score']))

        # Ground truth is identical every round (assuming dl is not shuffled
        # between rounds — TODO confirm); capture it once.
        if gt_mems is None:
            gt_mems = labels['score']
            gt_alphas = labels['alpha']

    # merge mem scores: average predictions across rounds
    mems_avg = np.array(mems).mean(axis=0)
    alphas_avg = np.array(alphas).mean(axis=0)
    rc_value = spearmanr(mems_avg, gt_mems)
    print("rc", rc_value)
    metrics = {'rc': rc_value.correlation}

    data = {
        'ckpt': ckpt_path,
        'mems': mems_avg.tolist(),
        'alphas': alphas_avg.tolist(),
        'gt_mems': gt_mems.tolist(),
        'gt_alphas': gt_alphas.tolist(),
        'metrics': metrics
    }

    with open(preds_savepath, "w") as outfile:
        print("Saving results")
        json.dump(data, outfile)
def __getitem__(self, vidpath) -> ModelOutput[MemCapModelFields]:
    """Build one sample for `vidpath`: scaled mem-score/alpha targets plus a
    randomly chosen caption in both input (embedding-list) and output
    (one-hot-list) form."""
    viddata = self.data_for_vidpath(vidpath)
    score = self.factor * viddata['mem_score']
    alpha = self.factor * viddata['alpha']

    cap_data = self.cap_data[self.vidname_from_path(vidpath)]
    # Choose one of this video's captions uniformly at random.
    cap_i = random.randint(0, len(cap_data['indexed_captions']) - 1)
    cap = cap_data['indexed_captions'][cap_i]
    tokenized_cap = cap_data['tokenized_captions'][cap_i]

    cap_in, cap_out = transform_caption(cap,
                                        tokenized_cap,
                                        input_format="embedding_list",
                                        caption_format="one_hot_list",
                                        add_padding=True,
                                        word_embeddings=self.word_embedding,
                                        max_cap_len=cfg.MAX_CAP_LEN,
                                        vocab_size=cfg.VOCAB_SIZE)

    return ModelOutput({
        'score': score,
        'alpha': alpha,
        'in_captions': np.array(cap_in).astype(np.float32, copy=False),
        'out_captions': cap_out.astype(np.float32, copy=False),
    })
def __getitem__(self, vidpath) -> ModelOutput[MemModelFields]:
    """Return the scaled memorability targets (score, alpha) for one video."""
    record = self.data_for_vidpath(vidpath)
    return ModelOutput({
        'score': self.factor * record['mem_score'],
        'alpha': self.factor * record['alpha'],
    })
def calc_captions(model, dl, ds, batch_size, device, vocab_embedding, idx2word,
                  preds_savepath, fnames):
    """Generate a caption for every item in ``dl`` and save a
    ``{filename: words}`` mapping to ``preds_savepath`` as JSON.

    Args:
        model: captioning model passed to ``cap_utils.predict_captions_simple``.
        dl: DataLoader yielding ``(x, y_)`` pairs; must have batch size 1.
        ds: dataset backing ``dl`` (used for the progress-bar total).
        batch_size: must be 1 — ``fnames[i]`` is indexed by batch index, so
            each batch must correspond to exactly one file.
        device: torch device to run on.
        vocab_embedding: word-embedding table for decoding.
        idx2word: index-to-word mapping for decoding.
        preds_savepath: path of the output JSON file.
        fnames: filenames aligned with dataset order.

    Raises:
        ValueError: if ``batch_size != 1``.
    """
    # Explicit check instead of `assert`, which is silently stripped
    # under `python -O`.
    if batch_size != 1:
        raise ValueError(
            "calc_captions requires batch_size == 1, got {}".format(batch_size))

    captions = {}

    with torch.no_grad():
        # batch_size is 1, so the total is simply len(ds) (and an int,
        # unlike the old float division).
        for i, (x, y_) in tqdm(enumerate(dl), total=len(ds)):
            y: ModelOutput[MemModelFields] = ModelOutput(y_)
            x = x.to(device)
            y = y.to_device(device)

            words = cap_utils.predict_captions_simple(model, x, device,
                                                      vocab_embedding,
                                                      idx2word)

            fname = fnames[i]
            captions[fname] = words
            print("simple", words)

    with open(preds_savepath, "w") as outfile:
        print("saving results")
        json.dump(captions, outfile)
def main(verbose: int = 1,
         print_freq: int = 100,
         restore: Union[bool, str] = True,
         val_freq: int = 1,
         run_id: str = "model",
         dset_name: str = "memento_frames",
         model_name: str = "frames",
         freeze_until_it: int = 1000,
         additional_metrics: Mapping[str, Callable] = {'rc': rc},
         debug_n: Optional[int] = None,
         batch_size: int = cfg.BATCH_SIZE,
         require_strict_model_load: bool = False,
         restore_optimizer=True,
         optim_string='adam',
         lr=0.01) -> None:
    """Train a memorability model, validating and checkpointing each epoch.

    Args:
        verbose: if truthy, print the stacked per-loss tensor every step.
        print_freq: unused in the visible body — TODO confirm / remove.
        restore: True to resume from the last checkpoint, a path string to
            resume from a specific checkpoint, False to start fresh.
        val_freq: run validation every ``val_freq`` epochs.
        run_id: subdirectory name for checkpoints and logs.
        dset_name: dataset name passed to ``get_dataset``.
        model_name: model name passed to ``get_model``.
        freeze_until_it: iteration after which all parameters are unfrozen.
        additional_metrics: name -> fn(labels, preds, losses) metrics to
            compute at validation time. NOTE: mutable default — the dict is
            never mutated here, but callers should not mutate it either.
        debug_n: if set, restrict train/test sets to the first n items.
        batch_size: batch size for both loaders.
        require_strict_model_load: strictness flag for state-dict loading.
        restore_optimizer: also restore optimizer state when resuming.
        optim_string: 'adam' or 'sgd'.
        lr: initial learning rate.

    Raises:
        RuntimeError: if ``optim_string`` is not recognized.
    """
    print("TRAINING MODEL {} ON DATASET {}".format(model_name, dset_name))
    ckpt_savedir = os.path.join(cfg.DATA_SAVEDIR, run_id, cfg.CKPT_DIR)
    print("Saving ckpts to {}".format(ckpt_savedir))
    logs_savepath = os.path.join(cfg.DATA_SAVEDIR, run_id, cfg.LOGDIR)
    print("Saving logs to {}".format(logs_savepath))
    utils.makedirs([ckpt_savedir, logs_savepath])
    last_ckpt_path = os.path.join(ckpt_savedir, "last_model.pth")

    device = utils.set_device()
    print('DEVICE', device)

    # model
    model = get_model(model_name, device)
    model = DataParallel(model)
    # must call this before constructing the optimizer:
    # https://pytorch.org/docs/stable/optim.html
    model.to(device)

    # set up training
    if optim_string == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optim_string == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=0.0001)
    else:
        raise RuntimeError(
            "Unrecognized optimizer string {}".format(optim_string))

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5,
                                                   gamma=0.1)

    losses = {
        'mem_mse':
        MemMSELoss(device=device, weights=np.load("memento_weights.npy")),
        'captions':
        CaptionsLoss(device=device,
                     class_weights=cap_utils.get_vocab_weights())
    }

    initial_epoch = 0
    iteration = 0
    unfrozen = False

    if restore:
        ckpt_path = restore if isinstance(restore, str) else last_ckpt_path
        if os.path.exists(ckpt_path):
            print("Restoring weights from {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path)
            utils.try_load_state_dict(model, ckpt['model_state_dict'],
                                      require_strict_model_load)
            if restore_optimizer:
                utils.try_load_optim_state(optimizer,
                                           ckpt['optimizer_state_dict'],
                                           require_strict_model_load)
            initial_epoch = ckpt['epoch']
            iteration = ckpt['it']
    else:
        ckpt_path = last_ckpt_path

    # dataset
    train_ds, val_ds, test_ds = get_dataset(dset_name)
    assert val_ds or test_ds

    if debug_n is not None:
        train_ds = Subset(train_ds, range(debug_n))
        test_ds = Subset(test_ds, range(debug_n))

    train_dl = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=cfg.NUM_WORKERS)
    test_dl = DataLoader(test_ds,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=cfg.NUM_WORKERS)

    # Integer ceil-division batch counts for tqdm (the old float totals
    # made the progress display inaccurate).
    n_train_batches = (len(train_ds) + batch_size - 1) // batch_size
    n_test_batches = (len(test_ds) + batch_size - 1) // batch_size

    # BUG FIX: the old code constructed a new SummaryWriter every epoch and
    # never closed any of them, leaking file handles; create one writer and
    # close it in the finally block below.
    logger = SummaryWriter(logs_savepath)

    # BUG FIX: make sure `epoch` is bound even if a KeyboardInterrupt
    # arrives before the first loop iteration (the old handler would
    # raise NameError in that case).
    epoch = initial_epoch

    # training loop
    start = time.time()

    try:
        for epoch in range(initial_epoch, cfg.NUM_EPOCHS):
            # effectively puts the model in train mode.
            # Opposite of model.eval()
            model.train()
            print("Epoch {}".format(epoch))

            for i, (x, y_) in tqdm(enumerate(train_dl), total=n_train_batches):
                y: ModelOutput[MemModelFields] = ModelOutput(y_)
                iteration += 1

                # One-shot unfreeze of the full network once warm-up is done.
                if not unfrozen and iteration > freeze_until_it:
                    print("Unfreezing encoder")
                    unfrozen = True
                    for param in model.parameters():
                        param.requires_grad = True

                logger.add_scalar('DataTime', time.time() - start, iteration)

                x = x.to(device)
                y = y.to_device(device)
                out = ModelOutput(model(x, y.get_data()))

                loss_vals = {name: l(out, y) for name, l in losses.items()}
                loss = torch.stack(list(loss_vals.values()))
                if verbose:
                    print("stacked loss", loss)
                loss = loss.sum()

                # zero out gradients accumulated from the previous step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # logging
                utils.log_loss(logger, loss, loss_vals, iteration)
                logger.add_scalar('ItTime', time.time() - start, iteration)
                start = time.time()

            # do some validation
            if (epoch + 1) % val_freq == 0:
                print("Validating...")
                model.eval()  # puts model in validation mode
                val_iteration = iteration

                with torch.no_grad():
                    labels: Optional[ModelOutput[MemModelFields]] = None
                    preds: Optional[ModelOutput[MemModelFields]] = None
                    val_losses = []

                    for i, (x, y_) in tqdm(enumerate(test_dl),
                                           total=n_test_batches):
                        val_iteration += 1

                        y = ModelOutput(y_)
                        y_numpy = y.to_numpy()
                        labels = y_numpy if labels is None else labels.merge(
                            y_numpy)

                        x = x.to(device)
                        y = y.to_device(device)
                        out = ModelOutput(model(x, y.get_data()))
                        out_numpy = out.to_device('cpu').to_numpy()
                        preds = out_numpy if preds is None else preds.merge(
                            out_numpy)

                        loss_vals = {
                            name: l(out, y)
                            for name, l in losses.items()
                        }
                        loss = torch.stack(list(loss_vals.values())).sum()
                        utils.log_loss(logger,
                                       loss,
                                       loss_vals,
                                       val_iteration,
                                       phase='val')
                        val_losses.append(loss)

                    print("Calculating validation metric...")
                    metrics = {
                        fname: f(labels, preds, losses)
                        for fname, f in additional_metrics.items()
                    }
                    print("Validation metrics", metrics)

                    for k, v in metrics.items():
                        if isinstance(v, numbers.Number):
                            logger.add_scalar('Metric_{}'.format(k), v,
                                              iteration)

                    metrics['total_val_loss'] = sum(val_losses)
                    ckpt_path = os.path.join(
                        ckpt_savedir, utils.get_ckpt_path(epoch, metrics))
                    save_ckpt(ckpt_path, model, epoch, iteration, optimizer,
                              dset_name, model_name, metrics)

            # end of epoch
            lr_scheduler.step()
            save_ckpt(last_ckpt_path, model, epoch, iteration, optimizer,
                      dset_name, model_name)
    except KeyboardInterrupt:
        print('Got keyboard interrupt, saving model...')
        save_ckpt(last_ckpt_path, model, epoch, iteration, optimizer,
                  dset_name, model_name)
    finally:
        logger.close()
def __call__(self, sample: ModelOutput):
    """Apply ``self.transform`` to every field of ``sample`` and return a new
    ModelOutput holding the transformed values."""
    transformed = {}
    for key, value in sample.items():
        # transform operates on a batch: wrap the single value, unwrap result
        transformed[key] = self.transform([value])[0]
    return ModelOutput(transformed)
def predict(ckpt_path,
            metrics: Mapping[str, Callable] = {'rc': rc},
            num_workers: int = 20,
            use_gpu: bool = True,
            model_name: str = "resnet3d",
            dset_name: str = "memento_frames",
            batch_size: int = 1,
            preds_savepath: Optional[str] = None,
            use_val: bool = False,
            debug_n: Optional[int] = None,
            final_activation: str = 'relu',
            shuffle=False):
    """Load a checkpoint, run inference over the val or test set, compute
    metrics, and save predictions + labels + metrics as JSON.

    Args:
        ckpt_path: path to the checkpoint to evaluate.
        metrics: name -> fn(labels, preds, losses) metric functions.
            NOTE: mutable default — read-only here; callers should not mutate.
        num_workers: DataLoader worker count.
        use_gpu: unused in the visible body — TODO confirm / remove.
        model_name: printed only; see NOTE below about the hard-coded model.
        dset_name: dataset to load via ``get_dataset``.
        batch_size: inference batch size.
        preds_savepath: output JSON path; derived from ``ckpt_path`` if None.
        use_val: evaluate on the validation split instead of the test split.
        debug_n: if set, restrict the dataset to the first n items.
        final_activation: passed to the model constructor.
        shuffle: whether the DataLoader shuffles.

    Returns:
        The computed metrics dict.

    Raises:
        ValueError: if the requested split does not exist.
    """
    print("ckpt path: {}".format(ckpt_path))

    if preds_savepath is None:
        preds_savepath = os.path.splitext(
            ckpt_path.replace(cfg.CKPT_DIR, cfg.PREDS_DIR))[0] + '.json'
        utils.makedirs([os.path.dirname(preds_savepath)])
    print("preds savepath: {}".format(preds_savepath))

    device = utils.set_device()
    print('DEVICE', device)

    # load the ckpt
    print("Loading model from path: {}".format(ckpt_path))
    ckpt = torch.load(ckpt_path)

    # model
    # NOTE(review): `model_name` is accepted and printed, but the
    # architecture is hard-coded to MemRestNet3D — confirm whether this
    # should dispatch through get_model like the training entry point.
    model = nn.DataParallel(MemRestNet3D(final_activation=final_activation))
    model.load_state_dict(ckpt['model_state_dict'], strict=True)
    model.to(device)
    model.eval()
    print("USING MODEL TYPE {} ON DSET {}".format(model_name, dset_name))

    # data loader
    # BUG FIX: get_dataset() was previously called with no argument,
    # silently ignoring the caller's `dset_name` (the training code calls
    # get_dataset(dset_name)).
    train, val, test = get_dataset(dset_name)
    ds = val if use_val else test

    if ds is None:
        raise ValueError("No {} set available for this dataset.".format(
            "val" if use_val else "test"))
    else:
        print("Using {} set".format("val" if use_val else "test"))

    if debug_n is not None:
        ds = Subset(ds, range(debug_n))

    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    num_workers=num_workers)

    preds: Optional[ModelOutput] = None
    labels: Optional[ModelOutput] = None

    # Ceil division so tqdm's total is an exact integer batch count.
    n_batches = (len(ds) + batch_size - 1) // batch_size

    with torch.no_grad():
        for i, (x, y_) in tqdm(enumerate(dl), total=n_batches):
            y: ModelOutput[MemModelFields] = ModelOutput(y_)
            y_list = y.to_numpy()
            labels = y_list if labels is None else labels.merge(y_list)

            x = x.to(device)
            y = y.to_device(device)
            out = ModelOutput(model(x, y.get_data()))
            out_list = out.to_device('cpu').to_numpy()
            preds = out_list if preds is None else preds.merge(out_list)

    metrics = {fname: f(labels, preds, None) for fname, f in metrics.items()}
    print("METRICS", metrics)

    data = {
        'ckpt': ckpt_path,
        'preds': preds.to_list().get_data(),
        'labels': labels.to_list().get_data(),
        'metrics': metrics
    }

    with open(preds_savepath, "w") as outfile:
        print("Saving results")
        json.dump(data, outfile)

    return metrics