import torch.nn.functional as F from models.model_helpers import ParamsIndexTracker # from nn.maskgen_rnn import MaskGenerator from nn.maskgen_topk import MaskGenerator from nn.rnn_base import RNNBase from optimize import log_pbar, log_tf_event from optimizers.optim_helpers import (BatchManagerArgument, DefaultIndexer, OptimizerBatchManager, OptimizerParams, OptimizerStates, StatesSlicingWrapper) from tqdm import tqdm from utils import utils from utils.result import ResultDict from utils.torchviz import make_dot C = utils.getCudaManager('default') debug_sigint = utils.getSignalCatcher('SIGINT') debug_sigstp = utils.getSignalCatcher('SIGTSTP') class Optimizer(nn.Module): def __init__(self, hidden_sz=20, preproc_factor=10.0, preproc=False, n_layers=1, rnn_cell='gru', sb_mode='unified'): super().__init__() assert sb_mode in ['none', 'normal', 'unified'] self.topk_mask_gen = MaskGenerator()
import pdb from collections import OrderedDict import gin import numpy as np import torch from loader.meta_dataset import MetaDataset, MetaMultiDataset, PseudoMetaDataset from nn.model import Model from torch.utils.tensorboard import SummaryWriter from utils import utils from utils.color import Color from utils.result import ResultDict, ResultFrame from utils.utils import Printer C = utils.getCudaManager('default') sig_1 = utils.getSignalCatcher('SIGINT') sig_2 = utils.getSignalCatcher('SIGTSTP') @gin.configurable def loop(mode, outer_steps, inner_steps, log_steps, fig_epochs, inner_lr, log_mask=True, unroll_steps=None, meta_batchsize=0, sampler=None, epoch=1,