import time

import numpy as np
import torch

# `util` (providing AverageMeter), `Progbar`, and the tensorboard `writer`
# are assumed to be defined elsewhere in this module, as in the original file.


def train(model, train_loader, epoch):
    # average meters to record the training statistics
    batch_time = util.AverageMeter()
    data_time = util.AverageMeter()

    # switch to train mode
    model.switch_to_train()

    progbar = Progbar(len(train_loader.dataset))
    end = time.time()
    for i, train_data in enumerate(train_loader):
        data_time.update(time.time() - end)

        vis_input, txt_input, _, _, _ = train_data
        loss = model.train(vis_input, txt_input)

        progbar.add(vis_input.size(0), values=[('data_time', data_time.val),
                                               ('batch_time', batch_time.val),
                                               ('loss', loss)])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Record logs in tensorboard
        writer.add_scalar('train/Loss', loss, model.iters)
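
# A minimal usage sketch for `train`, assuming the caller builds `model` and
# `train_loader` and chooses `num_epochs`; this driver loop is an assumption,
# not part of the original file.
def run_training(model, train_loader, num_epochs):
    """Hypothetical epoch driver that calls `train` once per epoch."""
    for epoch in range(num_epochs):
        train(model, train_loader, epoch)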
def encode_data(model, data_loader):
    """Encode all images and captions loadable by `data_loader`."""
    model.switch_to_eval()

    vis_embs = None
    txt_embs = None
    vis_ids = [''] * len(data_loader.dataset)
    txt_ids = [''] * len(data_loader.dataset)

    pbar = Progbar(len(data_loader.dataset))
    for i, (vis_input, txt_input, idxs, batch_vis_ids, batch_txt_ids) in enumerate(data_loader):
        with torch.no_grad():
            vis_emb = model.vis_net(vis_input)
            txt_emb = model.txt_net(txt_input)

        # lazily allocate the output arrays once the embedding size is known
        if vis_embs is None:
            vis_embs = np.zeros((len(data_loader.dataset), vis_emb.size(1)))
            txt_embs = np.zeros((len(data_loader.dataset), txt_emb.size(1)))

        # write each batch back at its dataset indices, so the output order
        # is independent of any loader shuffling
        vis_embs[idxs] = vis_emb.data.cpu().numpy().copy()
        txt_embs[idxs] = txt_emb.data.cpu().numpy().copy()

        for j, idx in enumerate(idxs):
            txt_ids[idx] = batch_txt_ids[j]
            vis_ids[idx] = batch_vis_ids[j]

        pbar.add(vis_emb.size(0))

    return vis_embs, txt_embs, vis_ids, txt_ids
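
# A minimal sketch of how the arrays returned by `encode_data` could be used
# for retrieval: cosine similarity between every caption and every image.
# The explicit normalization is an assumption; the networks may already
# return L2-normalized embeddings, in which case a plain dot product suffices.
def cosine_sim_matrix(txt_embs, vis_embs, eps=1e-12):
    """Return a (num_captions, num_images) cosine-similarity matrix."""
    txt = txt_embs / (np.linalg.norm(txt_embs, axis=1, keepdims=True) + eps)
    vis = vis_embs / (np.linalg.norm(vis_embs, axis=1, keepdims=True) + eps)
    return txt.dot(vis.T)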
def encode_vis(model, data_loader):
    """Encode all visual inputs loadable by `data_loader`."""
    model.switch_to_eval()

    vis_embs = None
    vis_ids = [''] * len(data_loader.dataset)

    pbar = Progbar(len(data_loader.dataset))
    for i, (vis_input, idxs, batch_vis_ids) in enumerate(data_loader):
        with torch.no_grad():
            vis_emb = model.vis_net(vis_input)

        if vis_embs is None:
            vis_embs = np.zeros((len(data_loader.dataset), vis_emb.size(1)))

        vis_embs[list(idxs)] = vis_emb.data.cpu().numpy().copy()
        for j, idx in enumerate(idxs):
            vis_ids[idx] = batch_vis_ids[j]

        pbar.add(len(idxs))

    return vis_embs, vis_ids
def encode_txt(model, data_loader):
    """Encode all captions loadable by `data_loader`."""
    model.switch_to_eval()

    txt_embs = None
    txt_ids = [''] * len(data_loader.dataset)

    pbar = Progbar(len(data_loader.dataset))
    for i, (txt_input, idxs, batch_txt_ids) in enumerate(data_loader):
        with torch.no_grad():
            txt_emb = model.txt_net(txt_input)

        if txt_embs is None:
            txt_embs = np.zeros((len(data_loader.dataset), txt_emb.size(1)))

        txt_embs[idxs] = txt_emb.data.cpu().numpy().copy()
        for j, idx in enumerate(idxs):
            txt_ids[idx] = batch_txt_ids[j]

        pbar.add(len(idxs))

    return txt_embs, txt_ids
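
# A hedged sketch tying `encode_vis` and `encode_txt` together for
# text-to-visual retrieval, reusing the hypothetical `cosine_sim_matrix`
# helper above. `vis_loader` and `txt_loader` are assumed to yield the
# per-modality batch formats expected by the two encoders.
def rank_visuals(model, vis_loader, txt_loader):
    """Return, per caption id, all visual ids sorted by descending score."""
    vis_embs, vis_ids = encode_vis(model, vis_loader)
    txt_embs, txt_ids = encode_txt(model, txt_loader)
    scores = cosine_sim_matrix(txt_embs, vis_embs)
    order = np.argsort(-scores, axis=1)
    return {txt_id: [vis_ids[k] for k in order[j]]
            for j, txt_id in enumerate(txt_ids)}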