def run(args):
    """ Recordar cambiar los nhops y el w_instructions del dataset, 
    y el metodo de embedding en make_batch
    tambien cambiar la normalizacion de la ultima instruccion"""
    logging.basicConfig(level=logging.INFO)
    pl.seed_everything(args.seed)
    model = NSMLightningModule.load_from_checkpoint(args.checkpoint_path)
    datamodule = ClevrNoImagesDataModule("data", batch_size=1, w_instructions=False, nhops=[0])
    # TODO: maybe change this to "validate"? For now I am only trying to
    # overfit, so it only makes sense to inspect training examples.
    datamodule.setup("validate")
    dataset = datamodule.clevr_val
    example = dataset.get_raw(random.randrange(len(dataset)))

    # Hook-based extraction of intermediate activations (a sketch of the hook
    # object is given after this function):
    # attn_hook = ExtractAttnHook()
    # node_probs = ExtractProbsHook()

    # handle1 = model.nsm.instructions_model.softmax.register_forward_hook(attn_hook)
    # handle2 = model.nsm.nsm_cell.register_forward_hook(node_probs)
    # out = model(*make_batch(example, dataset.vocab)[:-1])
    # handle1.remove()
    # handle2.remove()
    # q_attns = attn_hook.value.squeeze(0)  # remove first dim (batch dim)

    plt.ion()
    DataShell(model, dataset).cmdloop()
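
# A minimal sketch of the hook object used in the commented-out lines above.
# Assumption: ExtractAttnHook simply caches the output of the module it is
# registered on; the original class is not shown in this snippet.
class ExtractAttnHook:
    """Forward hook that stores the latest output of its module."""

    def __init__(self):
        self.value = None

    def __call__(self, module, inputs, output):
        # register_forward_hook invokes hooks as hook(module, inputs, output)
        self.value = output.detach()
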
def main():
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    pl.seed_everything(seed=123, workers=True)
    logging.basicConfig(level=logging.INFO)

    datamodule = clevr.ClevrWInstructionsDataModule(
        "data", batch_size=16, nhops=[0]
    )
    # most params obtained via inspection of dataset (see the sketch after
    # this function for deriving these sizes programmatically)
    model = NSMLightningModule(
        input_size=45,
        n_node_properties=4,
        computation_steps=0 + 1,  # nhops=[0], so a single computation step
        encoded_question_size=100,
        output_size=28,
        learn_rate=0.001,
        use_instruction_loss=False,
    )
    metric_to_track = "train_loss"
    trainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=1000,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor=metric_to_track, patience=300),
            pl.callbacks.ModelCheckpoint(save_top_k=-1, every_n_epochs=50),
        ],
    )
    trainer.fit(model, datamodule)
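
# Sketch only: the comment "most params obtained via inspection of dataset"
# suggests the hard-coded sizes above could be derived programmatically. The
# vocab attribute names below are assumptions about the datamodule's API, not
# part of the original code:
def inspect_sizes(datamodule):
    datamodule.setup("fit")
    # hypothetical attributes; substitute the datamodule's real vocab objects
    return len(datamodule.vocab), len(datamodule.answer_vocab)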
Example #3
def main_freeze_instructions(args):
    EMBEDDING_SIZE = 45
    N_INSTRUCTIONS = 2
    ENCODED_QUESTION_SIZE = 100
    # Train instructions model only
    datamodule = ClevrWInstructionsDataModule("data", batch_size=128, nhops=[0])
    instructions_model = InstructionsModelLightningModule(
        embedding_size=EMBEDDING_SIZE,
        n_instructions=N_INSTRUCTIONS,
        encoded_question_size=ENCODED_QUESTION_SIZE,
        learn_rate=0.001,
    )
    trainer_instructions = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=100,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="val_loss", patience=10),
            pl.callbacks.ModelCheckpoint(monitor="val_loss"),
        ],
    )
    trainer_instructions.fit(instructions_model, datamodule)

    # Now train the NSM, but load the instructions model from the previous step,
    # freeze its weights, and set it to eval mode (disables dropout; not relevant
    # now, but worth keeping in mind). A sanity-check helper follows this function.
    datamodule = ClevrWInstructionsDataModule(
        "data", batch_size=args.batch_size, nhops=[0]
    )
    nsm_model = NSMLightningModule(
        input_size=EMBEDDING_SIZE,
        n_node_properties=4,
        computation_steps=N_INSTRUCTIONS - 1,
        encoded_question_size=ENCODED_QUESTION_SIZE,
        output_size=28,
        learn_rate=args.learn_rate,
        use_instruction_loss=False,
    )

    nsm_model.nsm.instructions_model.load_state_dict(
        instructions_model.model.state_dict()
    )
    if not args.fine_tune:
        nsm_model.nsm.instructions_model.requires_grad_(False).eval()

    trainer_nsm = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=1000,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="train_loss", patience=100),
            pl.callbacks.ModelCheckpoint(monitor="train_loss"),
        ],
    )
    trainer_nsm.fit(nsm_model, datamodule)
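
# Illustrative helper (not in the original): after the freeze above, every
# parameter of the loaded instructions model should report requires_grad ==
# False unless --fine-tune was passed.
def assert_frozen(module):
    for name, param in module.named_parameters():
        assert not param.requires_grad, f"{name} is still trainable"

# e.g. assert_frozen(nsm_model.nsm.instructions_model)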
Example #4
def main(args):
    logging.basicConfig(level=logging.INFO)

    print(args)
    datamodule = cclevr.ComparisonClevrDataModule(
        datadir="data",
        batch_size=args.batch_size,
        subset_ratio=args.subset_ratio,
    )
    
    # most params obtained via inspection of dataset
    if args.model_type == "NSM":
        model = NSMLightningModule(
            input_size=45,
            n_node_properties=4,
            computation_steps=args.computation_steps,
            encoded_question_size=args.encoded_size,
            output_size=28,
            learn_rate=args.learn_rate,
            # use_instruction_loss=False,
        )
    elif args.model_type == "NSMBaseline":
        print(f"NSMBaseline chosen, ignoring {args.computation_steps=}")
        model = NSMBaselineLightningModule(
            input_size=45,
            n_node_properties=4,
            encoded_question_size=args.encoded_size,
            output_size=28,
            learn_rate=args.learn_rate,
            # use_instruction_loss=False,
        )
    metric_to_track = "val_acc"
    trainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=500,
        callbacks=[
            pl.callbacks.EarlyStopping(
                monitor=metric_to_track,
                patience=20,
                stopping_threshold=0.95,
                mode="max",
            ),
            pl.callbacks.ModelCheckpoint(monitor=metric_to_track),
        ],
    )
    trainer.fit(model, datamodule)
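
# The argparse setup that produces main(args) is not shown in this snippet; a
# plausible parser covering the attributes used above (flag names and defaults
# are assumptions, not the original CLI) would be:
import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-type", choices=["NSM", "NSMBaseline"], default="NSM")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--subset-ratio", type=float, default=1.0)
    parser.add_argument("--computation-steps", type=int, default=1)
    parser.add_argument("--encoded-size", type=int, default=100)
    parser.add_argument("--learn-rate", type=float, default=0.001)
    return parser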
Example #5
def main(args):
    model = NSMLightningModule.load_from_checkpoint(args.checkpoint)
    datamodule = ClevrNoImagesDataModule("data", batch_size=128)
    # iterate the validation split, matching the clevr_val questions used for
    # the category labels below
    datamodule.setup("validate")
    dataloader = datamodule.val_dataloader()
    question_cats = [
        q["question_family_index"] for q in datamodule.clevr_val.questions
    ]
    cat_totals = Counter(question_cats)

    predictions = []
    targets = []
    model.eval()
    with torch.no_grad():  # inference only; avoids accumulating autograd state
        for batch in tqdm.tqdm(dataloader, desc="Validating"):
            *first, target = batch
            predictions += model(*first).argmax(1).tolist()
            targets += target.tolist()

    cat_counts = defaultdict(int)
    for pred, target, cat in zip(predictions, targets, question_cats):
        cat_counts[cat] += pred == target

    cats = fix_cat_grid(np.array(CATS).reshape(4, -1))
    cat_grid = np.zeros(cats.shape, dtype=float)
    for i in range(cat_grid.shape[0]):
        for j in range(cat_grid.shape[1]):
            cat = cats[i, j]
            cat_grid[i, j] = cat_counts[cat] / cat_totals[cat]

    plt.imshow(cat_grid)
    # plot numbers
    for i, j in product(*map(range, cat_grid.shape)):
        plt.text(j,
                 i,
                 f"{cat_grid[i, j]:.2f}",
                 ha="center",
                 va="center",
                 color="w")
    plt.xticks(ticks=np.arange(len(X_LABELS)), labels=X_LABELS)
    plt.yticks(ticks=np.arange(len(Y_LABELS)), labels=Y_LABELS)
    plt.colorbar()
    plt.title(Path(args.checkpoint).parts[-3])
    plt.show()

def get_model(version, epoch):
    ckpt = next(
        p
        for p in Path(f"remote_logs/version_{version}/checkpoints").iterdir()
        if p.name.startswith(f"epoch={epoch}"))
    return NSMLightningModule.load_from_checkpoint(ckpt)
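
# Usage sketch for get_model: ModelCheckpoint saves files named like
# "epoch=49-step=1234.ckpt", so matching on the "epoch=" prefix selects a
# specific epoch. The version and epoch values here are illustrative only:
# model = get_model(version=74, epoch=49)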
from nsm.model import NSMLightningModule
from nsm.datasets import ClevrWInstructionsDataModule
import logging
import argparse
import tqdm
from pathlib import Path

logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", required=True)
parser.add_argument("--batch-size", required=True, type=int)
args = parser.parse_args()

datamodule = ClevrWInstructionsDataModule("data", args.batch_size, nhops=[0])
model = NSMLightningModule.load_from_checkpoint(args.checkpoint)

datamodule.setup("validate")

correct = []
for *inputs, targets, gold_instructions in tqdm.tqdm(
        datamodule.val_dataloader()):
    predictions, generated_instructions = model(*inputs)
    correct += predictions.argmax(1).eq(targets).tolist()

version = next(
    (p for p in Path(args.checkpoint).parts if p.startswith("version")), None)
with open(f"val_results_{version}.txt", "w") as outfile:
    for val in correct:
        outfile.write(str(val) + "\n")
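
# Not in the original script: reading the dump back is symmetric. Each line of
# the file is a "True"/"False" literal, one per validation example:
def read_accuracy(path):
    with open(path) as infile:
        flags = [line.strip() == "True" for line in infile]
    return sum(flags) / len(flags)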
Example #8
import logging
import collections
from pathlib import Path

import tqdm
import matplotlib.pyplot as plt
import numpy as np

from nsm.model import NSMLightningModule
from nsm.datasets import clevr

logging.basicConfig(level=logging.INFO)

VERSIONS = [74, 75, 76, 77, 78, 82]
CHECKPOINTS = [
    next(Path(f"remote_logs/version_{version}/checkpoints").iterdir())
    for version in VERSIONS
]
models = [
    NSMLightningModule.load_from_checkpoint(ckpt) for ckpt in CHECKPOINTS
]

datamodule = clevr.ClevrWInstructionsDataModule("data",
                                                batch_size=256,
                                                nhops=[0])
datamodule.setup("validate")

gold = []
results = collections.defaultdict(list)

for *inputs, targets, gold_instructions in tqdm.tqdm(
        datamodule.val_dataloader(), desc="Validating"):
    gold += targets.tolist()
    for i, model in enumerate(models):
        predictions, _ = model(*inputs)
        results[VERSIONS[i]] += predictions.argmax(1).tolist()
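
# Not part of the truncated original: with the inner loop completed as above,
# each model's validation accuracy is a direct comparison against the shared
# gold answers.
for version in VERSIONS:
    accuracy = np.mean([p == g for p, g in zip(results[version], gold)])
    print(f"version {version}: val acc {accuracy:.3f}")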
Example #9
def main_freeze_automaton(args):
    # Train automaton with my instructions
    EMBEDDING_SIZE = 45
    N_INSTRUCTIONS = 2
    ENCODED_SIZE = 100

    class NSMDummyInstructions(NSMLightningModule):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            self.nsm.instructions_model = DummyInstructionsModel(
                embedding_size=kwargs["input_size"],
                n_instructions=kwargs["computation_steps"] + 1,
                encoded_question_size=kwargs["encoded_question_size"],
            )

        @staticmethod
        def _prep(batch):
            # Swap batch element 1 (presumably the packed question) for a
            # packed version of element 5 (the gold instructions in this batch
            # layout), so the dummy instructions model consumes them directly.
            # See the pack_sequence demo after this function.
            batch = list(batch)
            batch[1] = torch.nn.utils.rnn.pack_sequence(batch[5].unbind(0))
            return tuple(batch)

        def training_step(self, batch, batch_idx):
            return super().training_step(self._prep(batch), batch_idx)

        def validation_step(self, batch, batch_idx):
            return super().validation_step(self._prep(batch), batch_idx)

    datamodule = ClevrWInstructionsDataModule("data", batch_size=32, nhops=[0])
    nsm_dummy_ins = NSMDummyInstructions(
        input_size=EMBEDDING_SIZE,
        n_node_properties=4,
        computation_steps=N_INSTRUCTIONS - 1,
        encoded_question_size=ENCODED_SIZE,
        output_size=28,
        learn_rate=0.001,
        use_instruction_loss=False,
    )
    trainer_automaton = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=10,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="train_loss", patience=3),
            pl.callbacks.ModelCheckpoint(monitor="train_loss"),
        ],
    )
    trainer_automaton.fit(nsm_dummy_ins, datamodule)

    # load the trained weights from the automaton into the NSM
    datamodule = ClevrWInstructionsDataModule(
        "data", batch_size=args.batch_size, nhops=[0]
    )
    nsm_model = NSMLightningModule(
        input_size=EMBEDDING_SIZE,
        n_node_properties=4,
        computation_steps=N_INSTRUCTIONS - 1,
        encoded_question_size=ENCODED_SIZE,
        output_size=28,
        learn_rate=args.learn_rate,
        use_instruction_loss=False,
    )
    nsm_model.nsm.nsm_cell.load_state_dict(nsm_dummy_ins.nsm.nsm_cell.state_dict())
    nsm_model.nsm.classifier.load_state_dict(nsm_dummy_ins.nsm.classifier.state_dict())
    if not args.fine_tune:
        nsm_model.nsm.nsm_cell.requires_grad_(False).eval()
        nsm_model.nsm.classifier.requires_grad_(False).eval()

    trainer_nsm = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=1000,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="train_loss", patience=200),
            pl.callbacks.ModelCheckpoint(monitor="train_loss"),
        ],
    )
    trainer_nsm.fit(nsm_model, datamodule)
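
# What NSMDummyInstructions._prep above does: swap batch element 1 for a packed
# version of batch element 5. A tiny self-contained demo of that pack_sequence
# call (shapes are illustrative only):
import torch

ins = torch.randn(3, 2, 45)  # (batch, n_instructions, embedding_size)
packed = torch.nn.utils.rnn.pack_sequence(ins.unbind(0))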
import argparse
import pathlib

import tqdm
import yaml

from nsm.model import NSMLightningModule
from nsm.datasets import ClevrWInstructionsDataModule

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", required=True)
args = parser.parse_args()

hparams = yaml.load(open(pathlib.Path(args.checkpoint) / "hparams.yaml"),
                    Loader=yaml.Loader)
datamodule = ClevrWInstructionsDataModule(
    "data",
    batch_size=1,
    nhops=hparams["nhops"],
    prop_embeds_const=hparams["prop_embeds_const"],
)
model = NSMLightningModule.load_from_checkpoint(
    next((pathlib.Path(args.checkpoint) / "checkpoints").iterdir()))

datamodule.setup("validate")

all_outs = []
all_tgts = []
for i, (*inputs, targets,
        gold_ins) in enumerate(tqdm.tqdm(datamodule.val_dataloader())):
    preds, gen_ins = model(*inputs)
    all_outs += preds.argmax(1).tolist()
    all_tgts += targets.tolist()
version = next(p for p in pathlib.Path(args.checkpoint).parts
               if p.startswith("version"))
with open(f"val_results_{version}.txt", "w") as out_f:
    for o, t in zip(all_outs, all_tgts):
        out_f.write(f"{o} {t}\n")