示例#1
0
import torch.nn.functional as F
from models.model_helpers import ParamsIndexTracker
# from nn.maskgen_rnn import MaskGenerator
from nn.maskgen_topk import MaskGenerator
from nn.rnn_base import RNNBase
from optimize import log_pbar, log_tf_event
from optimizers.optim_helpers import (BatchManagerArgument, DefaultIndexer,
                                      OptimizerBatchManager, OptimizerParams,
                                      OptimizerStates, StatesSlicingWrapper)
from tqdm import tqdm
from utils import utils
from utils.result import ResultDict
from utils.torchviz import make_dot

# Handle returned by utils.getCudaManager for the 'default' configuration.
# It is created once at import time and shared module-wide (presumably a
# device-placement helper -- confirm in utils).
C = utils.getCudaManager('default')
# One catcher per signal, built in order: SIGINT (typically Ctrl-C), then
# SIGTSTP (typically Ctrl-Z). What a catcher does when the signal arrives is
# defined in utils.getSignalCatcher -- presumably flags it for debugging.
debug_sigint, debug_sigstp = map(
    utils.getSignalCatcher, ('SIGINT', 'SIGTSTP'))


class Optimizer(nn.Module):
    """``nn.Module`` subclass named ``Optimizer``.

    Only the start of the constructor is visible in this excerpt. It checks
    ``sb_mode`` and attaches the ``MaskGenerator`` imported from
    ``nn.maskgen_topk``.

    NOTE(review): ``hidden_sz``, ``preproc_factor``, ``preproc``, ``n_layers``
    and ``rnn_cell`` are accepted but not used by the visible code. Presumably
    they are consumed by constructor code not shown here -- confirm.
    NOTE(review): no visible import binds the name ``nn`` (``nn.maskgen_topk``
    and ``nn.rnn_base`` are project packages imported via ``from``). Confirm
    that ``torch.nn`` is imported as ``nn`` elsewhere in the file.
    """

    def __init__(self,
                 hidden_sz=20,
                 preproc_factor=10.0,
                 preproc=False,
                 n_layers=1,
                 rnn_cell='gru',
                 sb_mode='unified'):
        """Initialise the module.

        Args:
          hidden_sz: unused in the visible code (default 20).
          preproc_factor: unused in the visible code (default 10.0).
          preproc: unused in the visible code (default False).
          n_layers: unused in the visible code (default 1).
          rnn_cell: unused in the visible code (default 'gru').
          sb_mode: one of 'none', 'normal' or 'unified' (default 'unified').

        Raises:
          AssertionError: if ``sb_mode`` is not an accepted value. This check
            is an ``assert``, so it is skipped when Python runs with ``-O``.
        """
        super().__init__()
        accepted_sb_modes = ('none', 'normal', 'unified')
        assert sb_mode in accepted_sb_modes
        # Mask generator instance stored on the module as `topk_mask_gen`.
        self.topk_mask_gen = MaskGenerator()
示例#2
0
import pdb
from collections import OrderedDict

import gin
import numpy as np
import torch
from loader.meta_dataset import MetaDataset, MetaMultiDataset, PseudoMetaDataset
from nn.model import Model
from torch.utils.tensorboard import SummaryWriter
from utils import utils
from utils.color import Color
from utils.result import ResultDict, ResultFrame
from utils.utils import Printer

# Handle returned by utils.getCudaManager for the 'default' configuration.
# It is created once at import time (presumably a device-placement helper --
# confirm in utils).
C = utils.getCudaManager('default')
# Signal catchers, built in order: sig_1 for SIGINT (typically Ctrl-C) and
# sig_2 for SIGTSTP (typically Ctrl-Z). Their runtime behavior is defined by
# utils.getSignalCatcher.
sig_1, sig_2 = map(utils.getSignalCatcher, ('SIGINT', 'SIGTSTP'))


@gin.configurable
def loop(mode,
         outer_steps,
         inner_steps,
         log_steps,
         fig_epochs,
         inner_lr,
         log_mask=True,
         unroll_steps=None,
         meta_batchsize=0,
         sampler=None,
         epoch=1,