예제 #1
0
def _calc_mem_scores(model, ckpt_path, dl, ds, batch_size, device,
                     preds_savepath, n_times):
    """Average model memorability predictions over ``n_times`` passes and
    save them, along with the ground truth and the Spearman rank
    correlation, as JSON at ``preds_savepath``.

    Args:
        model: callable taking (x, label data) and returning prediction
            fields including 'score' and 'alpha'.
        ckpt_path: checkpoint identifier recorded in the output JSON.
        dl: dataloader yielding (x, labels) batches over ``ds``.
        ds: the dataset; only its length is used (progress-bar total).
        batch_size: dataloader batch size (progress-bar total only).
        device: torch device to run inference on.
        preds_savepath: path of the output JSON file.
        n_times: number of full passes over the data to average.
    """
    alphas = []
    mems = []
    gt_mems = None
    gt_alphas = None

    for round_num in range(n_times):
        print("Generating mem scores, round {}".format(round_num))

        preds: Optional[ModelOutput] = None
        labels: Optional[ModelOutput] = None
        with torch.no_grad():
            # Iterate the loader directly: the previous enumerate index
            # was unused and shadowed the outer round counter.
            for x, y_ in tqdm(dl, total=len(ds) / batch_size):
                y: ModelOutput[MemModelFields] = ModelOutput(y_)
                y_list = y.to_numpy()
                labels = y_list if labels is None else labels.merge(y_list)

                x = x.to(device)
                y = y.to_device(device)

                out = ModelOutput(model(x, y.get_data()))

                out_list = out.to_device('cpu').to_numpy()
                preds = out_list if preds is None else preds.merge(out_list)

        mems.append(preds['score'])
        alphas.append(preds['alpha'])

        print("correlation", spearmanr(preds['score'], labels['score']))

        # Ground truth is identical across rounds; capture it once.
        if gt_mems is None:
            gt_mems = labels['score']
            gt_alphas = labels['alpha']

    # Average the per-round predictions.
    mems_avg = np.array(mems).mean(axis=0)
    alphas_avg = np.array(alphas).mean(axis=0)

    rc_value = spearmanr(mems_avg, gt_mems)
    print("rc", rc_value)

    metrics = {'rc': rc_value.correlation}

    data = {
        'ckpt': ckpt_path,
        'mems': mems_avg.tolist(),
        'alphas': alphas_avg.tolist(),
        'gt_mems': gt_mems.tolist(),
        'gt_alphas': gt_alphas.tolist(),
        'metrics': metrics
    }

    with open(preds_savepath, "w") as outfile:
        print("Saving results")
        json.dump(data, outfile)
예제 #2
0
    def __getitem__(self, vidpath) -> ModelOutput[MemCapModelFields]:
        """Build the training target for one video: the scaled memorability
        score and decay alpha, plus one randomly chosen caption encoded in
        both input (embedding-list) and output (one-hot-list) form.
        """
        viddata = self.data_for_vidpath(vidpath)
        score = self.factor * viddata['mem_score']
        alpha = self.factor * viddata['alpha']

        # Pick one caption uniformly at random among those for this video.
        cap_data = self.cap_data[self.vidname_from_path(vidpath)]
        cap_i = random.randint(0, len(cap_data['indexed_captions']) - 1)
        cap = cap_data['indexed_captions'][cap_i]
        tokenized_cap = cap_data['tokenized_captions'][cap_i]

        cap_in, cap_out = transform_caption(
            cap,
            tokenized_cap,
            input_format="embedding_list",
            caption_format="one_hot_list",
            add_padding=True,
            word_embeddings=self.word_embedding,
            max_cap_len=cfg.MAX_CAP_LEN,
            vocab_size=cfg.VOCAB_SIZE)

        return ModelOutput({
            'score': score,
            'alpha': alpha,
            'in_captions': np.array(cap_in).astype(np.float32, copy=False),
            'out_captions': cap_out.astype(np.float32, copy=False),
        })
예제 #3
0
    def __getitem__(self, vidpath) -> ModelOutput[MemModelFields]:
        """Return the scaled memorability score and decay alpha for the
        video at ``vidpath``."""
        viddata = self.data_for_vidpath(vidpath)

        return ModelOutput({
            'score': self.factor * viddata['mem_score'],
            'alpha': self.factor * viddata['alpha'],
        })
예제 #4
0
def calc_captions(model, dl, ds, batch_size, device, vocab_embedding, idx2word,
                  preds_savepath, fnames):
    """Generate a caption for each item in the dataloader and dump the
    resulting {filename: words} mapping as JSON to ``preds_savepath``.

    Args:
        model: the captioning model.
        dl: dataloader yielding (x, labels) batches over ``ds``.
        ds: the dataset (unused beyond pairing with ``dl``).
        batch_size: must be 1 so batch index i maps 1:1 onto ``fnames``.
        device: torch device to run inference on.
        vocab_embedding: vocabulary embedding passed to the caption
            predictor.
        idx2word: index-to-word lookup passed to the caption predictor.
        preds_savepath: output JSON path.
        fnames: filenames aligned with the dataloader's iteration order.
    """
    assert batch_size == 1
    captions = {}
    with torch.no_grad():
        # batch_size == 1, so the loader yields exactly len(ds) batches.
        # The labels (y_) are not needed for caption generation.
        for i, (x, y_) in tqdm(enumerate(dl), total=len(ds)):
            x = x.to(device)

            words = cap_utils.predict_captions_simple(model, x, device,
                                                      vocab_embedding,
                                                      idx2word)

            fname = fnames[i]
            captions[fname] = words
            print("simple", words)

    with open(preds_savepath, "w") as outfile:
        print("saving results")
        json.dump(captions, outfile)
예제 #5
0
def main(verbose: int = 1,
         print_freq: int = 100,
         restore: Union[bool, str] = True,
         val_freq: int = 1,
         run_id: str = "model",
         dset_name: str = "memento_frames",
         model_name: str = "frames",
         freeze_until_it: int = 1000,
         additional_metrics: Optional[Mapping[str, Callable]] = None,
         debug_n: Optional[int] = None,
         batch_size: int = cfg.BATCH_SIZE,
         require_strict_model_load: bool = False,
         restore_optimizer=True,
         optim_string='adam',
         lr=0.01) -> None:
    """Train a memorability model, validating and checkpointing as it goes.

    Args:
        verbose: if truthy, print the stacked loss every iteration.
        print_freq: currently unused; kept for interface compatibility.
        restore: True to resume from this run's last checkpoint, a path
            string to resume from a specific checkpoint, False to start
            from scratch.
        val_freq: run validation every ``val_freq`` epochs.
        run_id: subdirectory name for checkpoints and logs.
        dset_name: dataset identifier passed to ``get_dataset``.
        model_name: model identifier passed to ``get_model``.
        freeze_until_it: iteration after which all model parameters are
            unfrozen (``requires_grad = True``).
        additional_metrics: mapping of metric name -> fn(labels, preds,
            losses) computed at validation time; defaults to {'rc': rc}.
        debug_n: if set, restrict train/test sets to the first n examples.
        batch_size: minibatch size for both train and test loaders.
        require_strict_model_load: enforce strict state-dict matching on
            restore.
        restore_optimizer: also restore the optimizer state on restore.
        optim_string: 'adam' or 'sgd'.
        lr: initial learning rate (StepLR decays it by 0.1 every 5 epochs).

    Raises:
        RuntimeError: if ``optim_string`` is not recognized.
    """
    # Avoid a mutable default argument; {'rc': rc} stays the effective
    # default for callers that pass nothing.
    if additional_metrics is None:
        additional_metrics = {'rc': rc}

    print("TRAINING MODEL {} ON DATASET {}".format(model_name, dset_name))

    ckpt_savedir = os.path.join(cfg.DATA_SAVEDIR, run_id, cfg.CKPT_DIR)
    print("Saving ckpts to {}".format(ckpt_savedir))
    logs_savepath = os.path.join(cfg.DATA_SAVEDIR, run_id, cfg.LOGDIR)
    print("Saving logs to {}".format(logs_savepath))
    utils.makedirs([ckpt_savedir, logs_savepath])
    last_ckpt_path = os.path.join(ckpt_savedir, "last_model.pth")

    device = utils.set_device()
    print('DEVICE', device)

    # model
    model = get_model(model_name, device)
    model = DataParallel(model)

    # must call this before constructing the optimizer:
    # https://pytorch.org/docs/stable/optim.html
    model.to(device)

    if optim_string == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optim_string == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=0.0001)
    else:
        raise RuntimeError(
            "Unrecognized optimizer string {}".format(optim_string))

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5,
                                                   gamma=0.1)

    # Two loss terms summed each step: weighted MSE on memorability
    # scores plus a captioning loss with vocabulary class weights.
    losses = {
        'mem_mse':
        MemMSELoss(device=device, weights=np.load("memento_weights.npy")),
        'captions':
        CaptionsLoss(device=device,
                     class_weights=cap_utils.get_vocab_weights())
    }

    initial_epoch = 0
    iteration = 0
    unfrozen = False

    if restore:
        ckpt_path = restore if isinstance(restore, str) else last_ckpt_path

        if os.path.exists(ckpt_path):
            print("Restoring weights from {}".format(ckpt_path))

            ckpt = torch.load(ckpt_path)
            utils.try_load_state_dict(model, ckpt['model_state_dict'],
                                      require_strict_model_load)

            if restore_optimizer:
                utils.try_load_optim_state(optimizer,
                                           ckpt['optimizer_state_dict'],
                                           require_strict_model_load)
            initial_epoch = ckpt['epoch']
            iteration = ckpt['it']
    else:
        ckpt_path = last_ckpt_path

    # dataset
    train_ds, val_ds, test_ds = get_dataset(dset_name)
    assert val_ds or test_ds

    if debug_n is not None:
        train_ds = Subset(train_ds, range(debug_n))
        test_ds = Subset(test_ds, range(debug_n))

    train_dl = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=cfg.NUM_WORKERS)
    test_dl = DataLoader(test_ds,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=cfg.NUM_WORKERS)

    # training loop
    start = time.time()

    # Create the writer once (it was previously re-created every epoch,
    # leaking log-file handles), and pre-bind `epoch` so the
    # KeyboardInterrupt handler below can't hit a NameError when the
    # interrupt arrives before the first epoch starts.
    logger = SummaryWriter(logs_savepath)
    epoch = initial_epoch

    try:
        for epoch in range(initial_epoch, cfg.NUM_EPOCHS):
            # effectively puts the model in train mode.
            # Opposite of model.eval()
            model.train()

            print("Epoch {}".format(epoch))

            for i, (x, y_) in tqdm(enumerate(train_dl),
                                   total=len(train_ds) / batch_size):

                y: ModelOutput[MemModelFields] = ModelOutput(y_)
                iteration += 1

                # One-shot unfreeze of the (initially frozen) encoder
                # once enough iterations have elapsed.
                if not unfrozen and iteration > freeze_until_it:
                    print("Unfreezing encoder")
                    unfrozen = True

                    for param in model.parameters():
                        param.requires_grad = True

                logger.add_scalar('DataTime', time.time() - start, iteration)

                x = x.to(device)
                y = y.to_device(device)

                out = ModelOutput(model(x, y.get_data()))
                loss_vals = {name: l(out, y) for name, l in losses.items()}
                loss = torch.stack(list(loss_vals.values()))

                if verbose:
                    print("stacked loss", loss)
                loss = loss.sum()

                # Zero previous gradients before backprop so they don't
                # accumulate across iterations.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # logging
                utils.log_loss(logger, loss, loss_vals, iteration)
                logger.add_scalar('ItTime', time.time() - start, iteration)
                start = time.time()

            # do some validation
            if (epoch + 1) % val_freq == 0:
                print("Validating...")
                model.eval()  # puts model in validation mode
                val_iteration = iteration

                with torch.no_grad():

                    labels: Optional[ModelOutput[MemModelFields]] = None
                    preds: Optional[ModelOutput[MemModelFields]] = None
                    val_losses = []

                    for i, (x, y_) in tqdm(enumerate(test_dl),
                                           total=len(test_ds) / batch_size):
                        val_iteration += 1

                        y = ModelOutput(y_)
                        y_numpy = y.to_numpy()

                        labels = y_numpy if labels is None else labels.merge(
                            y_numpy)

                        x = x.to(device)
                        y = y.to_device(device)

                        out = ModelOutput(model(x, y.get_data()))
                        out_numpy = out.to_device('cpu').to_numpy()
                        preds = out_numpy if preds is None else preds.merge(
                            out_numpy)

                        loss_vals = {
                            name: l(out, y)
                            for name, l in losses.items()
                        }
                        loss = torch.stack(list(loss_vals.values())).sum()
                        utils.log_loss(logger,
                                       loss,
                                       loss_vals,
                                       val_iteration,
                                       phase='val')

                        val_losses.append(loss)

                    print("Calculating validation metric...")
                    metrics = {
                        fname: f(labels, preds, losses)
                        for fname, f in additional_metrics.items()
                    }
                    print("Validation metrics", metrics)

                    for k, v in metrics.items():
                        if isinstance(v, numbers.Number):
                            logger.add_scalar('Metric_{}'.format(k), v,
                                              iteration)

                    # NOTE: val_losses holds tensors, so this is a tensor
                    # sum; get_ckpt_path receives it as-is.
                    metrics['total_val_loss'] = sum(val_losses)

                    ckpt_path = os.path.join(
                        ckpt_savedir, utils.get_ckpt_path(epoch, metrics))
                    save_ckpt(ckpt_path, model, epoch, iteration, optimizer,
                              dset_name, model_name, metrics)

            # end of epoch
            lr_scheduler.step()

            save_ckpt(last_ckpt_path, model, epoch, iteration, optimizer,
                      dset_name, model_name)

    except KeyboardInterrupt:
        print('Got keyboard interrupt, saving model...')
        save_ckpt(last_ckpt_path, model, epoch, iteration, optimizer,
                  dset_name, model_name)
예제 #6
0
    def __call__(self, sample: ModelOutput):
        """Apply the wrapped transform to every field of ``sample``.

        Each value is wrapped in a singleton list before the call and
        unwrapped afterwards, so ``self.transform`` always sees a batch.
        """
        transformed = {}
        for key, value in sample.items():
            transformed[key] = self.transform([value])[0]

        return ModelOutput(transformed)
예제 #7
0
def predict(ckpt_path,
            metrics: Optional[Mapping[str, Callable]] = None,
            num_workers: int = 20,
            use_gpu: bool = True,
            model_name: str = "resnet3d",
            dset_name: str = "memento_frames",
            batch_size: int = 1,
            preds_savepath: Optional[str] = None,
            use_val: bool = False,
            debug_n: Optional[int] = None,
            final_activation: str = 'relu',
            shuffle=False):
    """Load a checkpoint, run inference over the val/test split, compute
    metrics, and dump predictions, labels, and metrics to a JSON file.

    Args:
        ckpt_path: checkpoint file to evaluate.
        metrics: mapping of metric name -> fn(labels, preds, losses);
            defaults to {'rc': rc}.
        num_workers: dataloader worker count.
        use_gpu: currently unused; device selection is via utils.set_device.
        model_name: informational only (printed); the architecture is
            MemRestNet3D regardless.
        dset_name: dataset identifier passed to ``get_dataset``.
        batch_size: inference batch size.
        preds_savepath: output JSON path; derived from ``ckpt_path`` when
            None.
        use_val: evaluate on the val split instead of test.
        debug_n: if set, evaluate on only the first n examples.
        final_activation: activation passed to the model constructor.
        shuffle: shuffle the dataloader.

    Returns:
        The computed metrics dict.

    Raises:
        ValueError: if the requested split does not exist.
    """
    # Avoid a mutable default argument; {'rc': rc} stays the effective
    # default for callers that pass nothing.
    if metrics is None:
        metrics = {'rc': rc}

    print("ckpt path: {}".format(ckpt_path))

    if preds_savepath is None:
        # Mirror the ckpt path under the preds dir, swapping in .json.
        preds_savepath = os.path.splitext(
            ckpt_path.replace(cfg.CKPT_DIR, cfg.PREDS_DIR))[0] + '.json'
        utils.makedirs([os.path.dirname(preds_savepath)])
    print("preds savepath: {}".format(preds_savepath))

    device = utils.set_device()
    print('DEVICE', device)

    # load the ckpt
    print("Loading model from path: {}".format(ckpt_path))
    ckpt = torch.load(ckpt_path)

    # model
    model = nn.DataParallel(MemRestNet3D(final_activation=final_activation))
    model.load_state_dict(ckpt['model_state_dict'], strict=True)

    model.to(device)
    model.eval()

    print("USING MODEL TYPE {} ON DSET {}".format(model_name, dset_name))

    # data loader; pass dset_name through (it was previously ignored,
    # which made this parameter a no-op, unlike in main())
    train, val, test = get_dataset(dset_name)
    ds = val if use_val else test

    if ds is None:
        raise ValueError("No {} set available for this dataset.".format(
            "val" if use_val else "test"))
    else:
        print("Using {} set".format("val" if use_val else "test"))

    if debug_n is not None:
        ds = Subset(ds, range(debug_n))

    dl = DataLoader(ds,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    num_workers=num_workers)

    preds: Optional[ModelOutput] = None
    labels: Optional[ModelOutput] = None
    with torch.no_grad():
        # Accumulate labels and predictions across all batches.
        for x, y_ in tqdm(dl, total=len(ds) / batch_size):
            y: ModelOutput[MemModelFields] = ModelOutput(y_)
            y_list = y.to_numpy()
            labels = y_list if labels is None else labels.merge(y_list)

            x = x.to(device)
            y = y.to_device(device)

            out = ModelOutput(model(x, y.get_data()))

            out_list = out.to_device('cpu').to_numpy()
            preds = out_list if preds is None else preds.merge(out_list)

    metrics = {fname: f(labels, preds, None) for fname, f in metrics.items()}
    print("METRICS", metrics)

    data = {
        'ckpt': ckpt_path,
        'preds': preds.to_list().get_data(),
        'labels': labels.to_list().get_data(),
        'metrics': metrics
    }

    with open(preds_savepath, "w") as outfile:
        print("Saving results")
        json.dump(data, outfile)

    return metrics