def run(args):
    """Load a checkpointed NSM model and drop into an interactive data shell.

    NOTE (translated from original): remember to change nhops and the
    dataset's w_instructions, the embedding method in make_batch, and the
    normalization of the last instruction.
    """
    logging.basicConfig(level=logging.INFO)
    pl.seed_everything(args.seed)

    model = NSMLightningModule.load_from_checkpoint(args.checkpoint_path)
    datamodule = ClevrNoImagesDataModule(
        "data", batch_size=1, w_instructions=False, nhops=[0]
    )
    # TODO (translated from original): maybe change this to "validate"?
    # For now I'm only trying to overfit, so it makes sense to only
    # inspect train examples.
    datamodule.setup("validate")
    dataset = datamodule.clevr_val

    # Draw one random example (kept to preserve the original RNG usage,
    # even though the hook-based inspection code was removed).
    example = dataset.get_raw(random.randrange(len(dataset)))

    plt.ion()  # interactive plotting so the shell stays responsive
    DataShell(model, dataset).cmdloop()
def main():
    """Train an NSM on CLEVR-with-instructions with fully deterministic setup."""
    # Required by CUBLAS for deterministic algorithms on GPU.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    pl.seed_everything(seed=123, workers=True)
    logging.basicConfig(level=logging.INFO)

    # most params obtained via inspection of dataset
    datamodule = clevr.ClevrWInstructionsDataModule("data", batch_size=16, nhops=[0])
    model = NSMLightningModule(
        input_size=45,
        n_node_properties=4,
        computation_steps=0 + 1,  # 0 hops plus one step -- TODO confirm intent
        encoded_question_size=100,
        output_size=28,
        learn_rate=0.001,
        use_instruction_loss=False,
    )

    monitored = "train_loss"
    trainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=1000,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor=monitored, patience=300),
            # Keep every checkpoint, saving once every 50 epochs.
            pl.callbacks.ModelCheckpoint(save_top_k=-1, every_n_epochs=50),
        ],
    )
    trainer.fit(model, datamodule)
def main_freeze_instructions(args):
    """Two-phase training: fit the instructions model alone, then train the
    NSM with those instruction weights loaded (frozen unless fine-tuning)."""
    EMBEDDING_SIZE = 45
    N_INSTRUCTIONS = 2
    ENCODED_QUESTION_SIZE = 100

    # ---- Phase 1: train the instructions model only -------------------------
    pretrain_data = ClevrWInstructionsDataModule("data", batch_size=128, nhops=[0])
    instructions_model = InstructionsModelLightningModule(
        embedding_size=EMBEDDING_SIZE,
        n_instructions=N_INSTRUCTIONS,
        encoded_question_size=ENCODED_QUESTION_SIZE,
        learn_rate=0.001,
    )
    instructions_trainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=100,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="val_loss", patience=10),
            pl.callbacks.ModelCheckpoint(monitor="val_loss"),
        ],
    )
    instructions_trainer.fit(instructions_model, pretrain_data)

    # ---- Phase 2: train the NSM with the pretrained instructions model ------
    # Freeze weights and set to eval (to disable dropout; not relevant now
    # but worth keeping in mind) unless fine-tuning was requested.
    datamodule = ClevrWInstructionsDataModule(
        "data", batch_size=args.batch_size, nhops=[0]
    )
    nsm_model = NSMLightningModule(
        input_size=EMBEDDING_SIZE,
        n_node_properties=4,
        computation_steps=N_INSTRUCTIONS - 1,
        encoded_question_size=ENCODED_QUESTION_SIZE,
        output_size=28,
        learn_rate=args.learn_rate,
        use_instruction_loss=False,
    )
    nsm_model.nsm.instructions_model.load_state_dict(
        instructions_model.model.state_dict()
    )
    if not args.fine_tune:
        nsm_model.nsm.instructions_model.requires_grad_(False).eval()

    nsm_trainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=1000,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="train_loss", patience=100),
            pl.callbacks.ModelCheckpoint(monitor="train_loss"),
        ],
    )
    nsm_trainer.fit(nsm_model, datamodule)
def main(args):
    """Train an NSM or NSMBaseline model on the comparison-CLEVR dataset.

    Args:
        args: parsed CLI namespace providing ``model_type``, ``batch_size``,
            ``subset_ratio``, ``computation_steps``, ``encoded_size`` and
            ``learn_rate``.

    Raises:
        ValueError: if ``args.model_type`` is not a known model type.
    """
    logging.basicConfig(level=logging.INFO)
    print(args)
    datamodule = cclevr.ComparisonClevrDataModule(
        datadir="data",
        batch_size=args.batch_size,
        subset_ratio=args.subset_ratio,
    )
    # most params obtained via inspection of dataset
    if args.model_type == "NSM":
        model = NSMLightningModule(
            input_size=45,
            n_node_properties=4,
            computation_steps=args.computation_steps,
            encoded_question_size=args.encoded_size,
            output_size=28,
            learn_rate=args.learn_rate,
        )
    elif args.model_type == "NSMBaseline":
        print(f"NSMBaseline chosen, ignoring {args.computation_steps=}")
        model = NSMBaselineLightningModule(
            input_size=45,
            n_node_properties=4,
            encoded_question_size=args.encoded_size,
            output_size=28,
            learn_rate=args.learn_rate,
        )
    else:
        # Fail fast with a clear message instead of an unbound-`model`
        # NameError at trainer.fit below.
        raise ValueError(f"unknown model_type: {args.model_type!r}")

    metric_to_track = "val_acc"
    trainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=500,
        callbacks=[
            pl.callbacks.EarlyStopping(
                monitor=metric_to_track,
                patience=20,
                stopping_threshold=0.95,
                mode="max",
            ),
            pl.callbacks.ModelCheckpoint(monitor=metric_to_track),
        ],
    )
    trainer.fit(model, datamodule)
def main(args):
    """Plot per-question-category accuracy for a checkpointed NSM model.

    Loads the checkpoint named by ``args.checkpoint``, scores every batch,
    aggregates correctness per question_family_index, and shows a heatmap
    of per-category accuracy annotated with the numeric values.
    """
    import torch  # local import: this chunk's top-of-file imports are not visible

    model = NSMLightningModule.load_from_checkpoint(args.checkpoint)
    model.eval()  # evaluation-only script: disable dropout and the like
    datamodule = ClevrNoImagesDataModule("data", batch_size=128)
    datamodule.setup("fit")
    dataloader = datamodule.train_dataloader()
    # NOTE(review): categories come from clevr_val but predictions come from
    # the train dataloader -- confirm these index the same question set.
    question_cats = [
        q["question_family_index"] for q in datamodule.clevr_val.questions
    ]
    cat_totals = Counter(question_cats)

    predictions = []
    targets = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for batch in tqdm.tqdm(dataloader, desc="Validating"):
            *first, target = batch
            predictions += model(*first).argmax(1).tolist()
            targets += target.tolist()

    # Count correct answers per category.
    cat_counts = defaultdict(int)
    for pred, target, cat in zip(predictions, targets, question_cats):
        cat_counts[cat] += pred == target

    # Arrange category ids in a 4-row grid, then replace each id with its
    # accuracy. Read ids from the integer grid rather than the float copy so
    # the dict lookups use the original (integer) category keys.
    cat_ids = fix_cat_grid(np.array(CATS).reshape(4, -1))
    cat_grid = cat_ids.astype(float)
    for i in range(cat_grid.shape[0]):
        for j in range(cat_grid.shape[1]):
            cat = cat_ids[i, j]
            cat_grid[i, j] = cat_counts[cat] / cat_totals[cat]

    plt.imshow(cat_grid)
    # Annotate every cell with its accuracy.
    for i, j in product(*map(range, cat_grid.shape)):
        plt.text(j, i, f"{cat_grid[i, j]:.2f}", ha="center", va="center", color="w")
    plt.xticks(ticks=np.arange(len(X_LABELS)), labels=X_LABELS)
    plt.yticks(ticks=np.arange(len(Y_LABELS)), labels=Y_LABELS)
    plt.colorbar()
    plt.title(Path(args.checkpoint).parts[-3])
    plt.show()
def get_model(version, epoch):
    """Load the checkpoint for *epoch* from logger directory *version*."""
    ckpt_dir = Path(f"remote_logs/version_{version}/checkpoints")
    # Keep `next` so a missing checkpoint still raises StopIteration,
    # matching the original behavior.
    matches = (p for p in ckpt_dir.iterdir() if p.name.startswith(f"epoch={epoch}"))
    return NSMLightningModule.load_from_checkpoint(next(matches))
"""Score a checkpoint on the validation set and dump per-example correctness."""
import argparse
import logging
from pathlib import Path

import torch
import tqdm

from nsm.model import NSMLightningModule
from nsm.datasets import ClevrWInstructionsDataModule

logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", required=True)
parser.add_argument("--batch-size", required=True, type=int)
args = parser.parse_args()

datamodule = ClevrWInstructionsDataModule("data", args.batch_size, nhops=[0])
model = NSMLightningModule.load_from_checkpoint(args.checkpoint)
model.eval()  # evaluation-only script: disable dropout and the like
datamodule.setup("validate")

correct = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for *inputs, targets, gold_instructions in tqdm.tqdm(datamodule.val_dataloader()):
        predictions, generated_instructions = model(*inputs)
        correct += predictions.argmax(1).eq(targets).tolist()

# Name the output after the "version_*" component of the checkpoint path, if any.
version = next(
    (p for p in Path(args.checkpoint).parts if p.startswith("version")), None
)
with open(f"val_results_{version}.txt", "w") as outfile:
    for val in correct:
        outfile.write(str(val) + "\n")
# Compare validation predictions across several trained model versions.
import logging
import collections
from pathlib import Path
import tqdm
import matplotlib.pyplot as plt
import numpy as np

logging.basicConfig(level=logging.INFO)

# Logger versions whose checkpoints will be compared.
VERSIONS = [74, 75, 76, 77, 78, 82]
# NOTE(review): assumes exactly one checkpoint file per version directory.
CHECKPOINTS = [
    next(Path(f"remote_logs/version_{version}/checkpoints").iterdir())
    for version in VERSIONS
]
# NOTE(review): NSMLightningModule and clevr are not imported in this chunk;
# presumably they are imported elsewhere in the file -- confirm.
models = [
    NSMLightningModule.load_from_checkpoint(ckpt) for ckpt in CHECKPOINTS
]
datamodule = clevr.ClevrWInstructionsDataModule("data", batch_size=256, nhops=[0])
datamodule.setup("validate")

gold = []  # flat list of gold answers across the whole validation set
results = collections.defaultdict(list)  # intended per-model prediction store
for *inputs, targets, gold_instructins in tqdm.tqdm(
    datamodule.val_dataloader(), desc=f"Validating"
):
    gold += targets.tolist()
    for i, model in enumerate(models):
        predictions, _ = model(*inputs)
        # NOTE(review): `predictions` is computed but never stored; this block
        # appears truncated (`results`, plt and np are otherwise unused).
def main_freeze_automaton(args):
    """Pre-train the automaton (NSM cell + classifier) on gold instructions,
    then train a full NSM with those weights loaded (frozen unless
    ``args.fine_tune`` is set)."""
    EMBEDDING_SIZE = 45
    N_INSTRUCTIONS = 2
    ENCODED_SIZE = 100

    class NSMDummyInstructions(NSMLightningModule):
        """NSM variant that bypasses instruction generation and feeds the
        dataset's gold instructions straight to the automaton."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.nsm.instructions_model = DummyInstructionsModel(
                embedding_size=kwargs["input_size"],
                n_instructions=kwargs["computation_steps"] + 1,
                encoded_question_size=kwargs["encoded_question_size"],
            )

        @staticmethod
        def _prep(batch):
            # Swap the question (index 1) for the gold instructions
            # (index 5), packed as a sequence.
            prepped = list(batch)
            prepped[1] = torch.nn.utils.rnn.pack_sequence(prepped[5].unbind(0))
            return tuple(prepped)

        def training_step(self, batch, batch_idx):
            return super().training_step(self._prep(batch), batch_idx)

        def validation_step(self, batch, batch_idx):
            return super().validation_step(self._prep(batch), batch_idx)

    # ---- Phase 1: train the automaton with gold instructions ----------------
    pretrain_data = ClevrWInstructionsDataModule("data", batch_size=32, nhops=[0])
    dummy_model = NSMDummyInstructions(
        input_size=EMBEDDING_SIZE,
        n_node_properties=4,
        computation_steps=N_INSTRUCTIONS - 1,
        encoded_question_size=ENCODED_SIZE,
        output_size=28,
        learn_rate=0.001,
        use_instruction_loss=False,
    )
    pretrainer = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=10,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="train_loss", patience=3),
            pl.callbacks.ModelCheckpoint(monitor="train_loss"),
        ],
    )
    pretrainer.fit(dummy_model, pretrain_data)

    # ---- Phase 2: load the trained automaton weights into a fresh NSM -------
    datamodule = ClevrWInstructionsDataModule(
        "data", batch_size=args.batch_size, nhops=[0]
    )
    nsm_model = NSMLightningModule(
        input_size=EMBEDDING_SIZE,
        n_node_properties=4,
        computation_steps=N_INSTRUCTIONS - 1,
        encoded_question_size=ENCODED_SIZE,
        output_size=28,
        learn_rate=args.learn_rate,
        use_instruction_loss=False,
    )
    nsm_model.nsm.nsm_cell.load_state_dict(dummy_model.nsm.nsm_cell.state_dict())
    nsm_model.nsm.classifier.load_state_dict(dummy_model.nsm.classifier.state_dict())
    if not args.fine_tune:
        # Freeze the automaton and set to eval (disables dropout etc.).
        # NOTE(review): the source formatting is ambiguous about whether the
        # classifier freeze sits inside this branch; placed here to mirror
        # the cell freeze -- confirm against the original file.
        nsm_model.nsm.nsm_cell.requires_grad_(False).eval()
        nsm_model.nsm.classifier.requires_grad_(False).eval()

    trainer_nsm = pl.Trainer(
        gpus=-1 if torch.cuda.is_available() else 0,
        max_epochs=1000,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="train_loss", patience=200),
            pl.callbacks.ModelCheckpoint(monitor="train_loss"),
        ],
    )
    trainer_nsm.fit(nsm_model, datamodule)
import argparse import yaml parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True) args = parser.parse_args() hparams = yaml.load(open(pathlib.Path(args.checkpoint + "/hparams.yaml")), Loader=yaml.Loader) datamodule = ClevrWInstructionsDataModule( "data", batch_size=1, nhops=hparams["nhops"], prop_embeds_const=hparams["prop_embeds_const"], ) model = NSMLightningModule.load_from_checkpoint( next(pathlib.Path(args.checkpoint + "/checkpoints").iterdir())) datamodule.setup("validate") all_outs = [] all_tgts = [] for i, (*inputs, targets, gold_ins) in enumerate(tqdm.tqdm(datamodule.val_dataloader())): preds, gen_ins = model(*inputs) all_outs += preds.argmax(1).tolist() all_tgts += targets.tolist() version = next(p for p in pathlib.Path(args.checkpoint).parts if p.startswith("version")) breakpoint() with open(f"val_results_{version}.txt", "w") as out_f: for o, t in zip(all_outs, all_tgts):