def create_dqn_model_dict(ch, num_actions, model_dict=None):
    """Build the ensemble-DQN policy/target networks and their optimizer.

    Creates a policy net and a target net (EnsembleNet), optionally wraps
    both with a randomized prior (NetWithPrior), syncs the target net to the
    policy net's weights, puts every model in eval mode, and attaches an Adam
    optimizer over the policy net's parameters.

    Parameters
    ----------
    ch : config handler exposing ``cfg`` (reads the 'REP' and 'DQN'
        sections) and a ``device`` attribute.
    num_actions : int
        Number of discrete actions each ensemble head outputs.
    model_dict : dict or None
        Optional dict to populate; a fresh dict is created when None.

    Returns
    -------
    dict with keys 'policy_net', 'target_net', and 'opt'.
    """
    # BUG FIX: the original signature used a mutable default (model_dict={}),
    # which Python shares across calls — models from one call would leak into
    # the next. Use the None-sentinel idiom instead.
    if model_dict is None:
        model_dict = {}
    cl = 192  # todo load this
    rsize = cl * ch.cfg['REP']['num_prev_steps']
    model_dict['policy_net'] = EnsembleNet(
        n_ensemble=ch.cfg['DQN']['n_ensemble'],
        n_actions=num_actions,
        code_length=rsize,
        num_hidden=ch.cfg['DQN']['n_hidden'],
        dueling=ch.cfg['DQN']['dueling']).to(ch.device)

    model_dict['target_net'] = EnsembleNet(
        n_ensemble=ch.cfg['DQN']['n_ensemble'],
        n_actions=num_actions,
        code_length=rsize,
        num_hidden=ch.cfg['DQN']['n_hidden'],
        dueling=ch.cfg['DQN']['dueling']).to(ch.device)

    if ch.cfg['DQN']['prior']:
        print("using randomized prior")
        prior_net = EnsembleNet(n_ensemble=ch.cfg['DQN']['n_ensemble'],
                                n_actions=num_actions,
                                code_length=rsize,
                                num_hidden=ch.cfg['DQN']['n_hidden'],
                                dueling=ch.cfg['DQN']['dueling']).to(ch.device)

        # NOTE(review): the SAME prior_net instance is shared by the policy
        # and target wrappers — presumably intentional (a fixed random prior
        # common to both); confirm against NetWithPrior's contract.
        model_dict['policy_net'] = NetWithPrior(model_dict['policy_net'],
                                                prior_net,
                                                ch.cfg['DQN']['prior_scale'])
        model_dict['target_net'] = NetWithPrior(model_dict['target_net'],
                                                prior_net,
                                                ch.cfg['DQN']['prior_scale'])

    # start the target net identical to the policy net
    model_dict['target_net'].load_state_dict(
        model_dict['policy_net'].state_dict())

    for name, model in model_dict.items():
        print('created %s model with %s parameters' %
              (name, count_parameters(model)))
        model.eval()

    # only the policy net is trained; the target net is synced periodically
    model_dict['opt'] = optim.Adam(model_dict['policy_net'].parameters(),
                                   lr=ch.cfg['DQN']['adam_learning_rate'])
    return model_dict
def create_models(info, model_loadpath='', dataset_name='FashionMNIST'):
    '''
    load details of previous model if applicable, otherwise create new models

    Builds the forward ACN-VQVAE model, its prior network, and a gated
    PixelCNN decoder, wires them to a single Adam optimizer, and (when
    ``model_loadpath`` is given) restores all state dicts from the
    checkpoint.

    Parameters
    ----------
    info : dict of run configuration; mutated in place with dataset- and
        architecture-derived keys (frame sizes, nmix, output_dim, ...).
    model_loadpath : str, path to a torch checkpoint; '' means fresh models.
    dataset_name : str, unused here — kept for signature compatibility.

    Returns
    -------
    (model_dict, data_dict, info, train_cnt, epoch_cnt, rescale, rescale_inv)
    NOTE(review): ``rescale`` and ``rescale_inv`` are not defined in this
    function — they must be module-level globals; verify they exist.

    NOTE(review): this def is shadowed by an identically-named def later in
    the file, so it is unreachable unless imported before that definition.
    '''
    train_cnt = 0
    epoch_cnt = 0

    # use argparse device no matter what info dict is loaded
    preserve_args = ['device', 'batch_size', 'save_every_epochs',
                     'base_filepath', 'model_loadpath', 'perplexity',
                     'use_pred']
    largs = info['args']
    # load model if given a path
    if model_loadpath != '':
        _dict = torch.load(model_loadpath, map_location=lambda storage, loc: storage)
        dinfo = _dict['info']
        pkeys = info.keys()
        # adopt every checkpoint setting unless it is both preserved and
        # already present in the live info dict
        for key in dinfo.keys():
            if key not in preserve_args or key not in pkeys:
                info[key] = dinfo[key]
        train_cnt = info['train_cnts'][-1]
        epoch_cnt = info['epoch_cnt']
        info['args'].append(largs)

    # transform is dependent on loss type
    data_dict, data_paths = make_random_subset_buffers(dataset_path=info['base_datadir'],
                                           buffer_path=info['base_train_buffer_path'],
                                           train_max_examples=info['size_training_set'],
                                           kernel_size=info['frame_shrink_kernel_size'],
                                                       trim=info['frame_shrink_trim'])

    info['frame_height'] = data_dict['train'].frame_height
    info['frame_width'] = data_dict['train'].frame_width
    info['num_actions'] = data_dict['train'].num_actions()
    info['num_rewards'] = data_dict['train'].num_rewards()
    # assume actions/rewards are 0 indexed
    assert(min(data_dict['train'].rewards) == 0)
    assert(max(data_dict['train'].rewards) == info['num_rewards']-1)
    assert(min(data_dict['train'].actions) == 0)
    assert(max(data_dict['train'].actions) == info['num_actions']-1)

    # pixel cnn architecture is dependent on loss
    # for dml prediction, need to output mixture of size nmix
    info['nmix'] = (2*info['nr_logistic_mix']+info['nr_logistic_mix'])*info['target_channels']
    info['output_dim'] = info['nmix']
    # last layer for pcnn - bias is 0 for dml
    info['last_layer_bias'] = 0.0

    # setup models
    # acn prior with vqvae embedding
    fwd_vq_acn_model = fwdACNVQVAEres(code_len=info['code_length'],
                               input_size=info['input_channels'],
                               output_size=info['output_dim'],
                               hidden_size=info['hidden_size'],
                               num_clusters=info['num_vqk'],
                               num_z=info['num_z'],
                               num_actions=info['num_actions'],
                               num_rewards=info['num_rewards'],
                               ).to(info['device'])

    prior_model = tPTPriorNetwork(size_training_set=info['size_training_set'],
                               code_length=info['code_length'], k=info['num_k']).to(info['device'])
    prior_model.codes = prior_model.codes.to(info['device'])

    pcnn_decoder = GatedPixelCNN(input_dim=info['target_channels'],
                                 output_dim=info['output_dim'],
                                 dim=info['pixel_cnn_dim'],
                                 n_layers=info['num_pcnn_layers'],
                                 # output dim is same as deconv output in this
                                 # case
                                 spatial_condition_size=info['output_dim'],
                                 last_layer_bias=info['last_layer_bias'],
                                 use_batch_norm=False,
                                 output_projection_size=info['output_projection_size']).to(info['device'])

    model_dict = {'fwd_vq_acn_model': fwd_vq_acn_model, 'prior_model': prior_model, 'pcnn_decoder_model': pcnn_decoder}
    # one optimizer over all three models' parameters
    parameters = []
    for name, model in model_dict.items():
        parameters += list(model.parameters())
        print('created %s model with %s parameters' % (name, count_parameters(model)))

    model_dict['opt'] = optim.Adam(parameters, lr=info['learning_rate'])

    # BUG FIX: the original tested the global ``args.model_loadpath`` here,
    # inconsistent with the ``model_loadpath`` parameter used above (and with
    # the sibling create_models definition) — restoring state would silently
    # depend on module-level argparse state.
    if model_loadpath != '':
        for name, model in model_dict.items():
            model_dict[name].load_state_dict(_dict[name+'_state_dict'])
    return model_dict, data_dict, info, train_cnt, epoch_cnt, rescale, rescale_inv
def create_models(info,
                  model_loadpath='',
                  dataset_name='FashionMNIST',
                  load_data=True):
    '''
    load details of previous model if applicable, otherwise create new models

    Sets frame geometry from the small/big flags, derives the DML mixture
    size, optionally loads the dataset, builds the mid ACN-VQVAE model plus
    its prior network with a shared Adam optimizer, and restores checkpoint
    state when ``model_loadpath`` is given.
    '''
    train_cnt = 0
    epoch_cnt = 0

    # argparse-provided settings that must survive a checkpoint reload
    preserve_args = [
        'device', 'batch_size', 'save_every_epochs', 'base_filepath',
        'model_loadpath', 'perplexity', 'use_pred'
    ]
    largs = info['args']
    # restore run configuration from checkpoint when a path is supplied
    loaded_dict = None
    if model_loadpath != '':
        loaded_dict = torch.load(model_loadpath,
                                 map_location=lambda storage, loc: storage)
        old_info = loaded_dict['info']
        current_keys = info.keys()
        # adopt each checkpoint setting unless it is both preserved and
        # already present in the live info dict
        for key in old_info.keys():
            if key not in preserve_args or key not in current_keys:
                info[key] = old_info[key]
        train_cnt = info['train_cnts'][-1]
        epoch_cnt = info['epoch_cnt']
        info['args'].append(largs)

    # frame geometry: (shrink kernel, trim after, trim before, height, width)
    if info['small']:
        geometry = ((4, 4), 0, 2, 20, 20)
    elif info['big']:
        # dont adjust replay buffer size
        geometry = ((0, 0), 0, 0, 84, 84)
    else:
        geometry = ((2, 2), 1, 0, 40, 40)
    (info['frame_shrink_kernel_size'], info['frame_shrink_trim_after'],
     info['frame_shrink_trim_before'], info['frame_height'],
     info['frame_width']) = geometry

    # pixel cnn architecture is dependent on loss
    # for dml prediction, need to output mixture of size nmix
    info['nmix'] = (2 * info['nr_logistic_mix'] +
                    info['nr_logistic_mix']) * info['target_channels']
    info['output_dim'] = info['nmix']

    data_dict = {}
    if load_data:
        data_dict, info = load_data_fn(info)

    # setup models
    # acn prior with vqvae embedding
    mid_vq_acn_model = midACNVQVAEres(code_len=info['code_length'],
                                      input_size=info['input_channels'],
                                      output_size=info['output_dim'],
                                      hidden_size=info['hidden_size'],
                                      num_clusters=info['num_vqk'],
                                      num_z=info['num_z'],
                                      num_actions=info['num_actions'],
                                      num_rewards=info['num_rewards'],
                                      small=info['small'],
                                      big=info['big']).to(info['device'])

    prior_model = tPTPriorNetwork(size_training_set=info['size_training_set'],
                                  code_length=info['code_length'],
                                  k=info['num_k']).to(info['device'])
    prior_model.codes = prior_model.codes.to(info['device'])

    model_dict = {
        'mid_vq_acn_model': mid_vq_acn_model,
        'prior_model': prior_model
    }
    # single optimizer spanning both models' parameters
    all_params = []
    for model_name, net in model_dict.items():
        all_params += list(net.parameters())
        print('created %s model with %s parameters' %
              (model_name, count_parameters(net)))

    model_dict['opt'] = optim.Adam(all_params, lr=info['learning_rate'])

    # restore each model's weights from the checkpoint loaded above
    if model_loadpath != '':
        for model_name in list(model_dict.keys()):
            model_dict[model_name].load_state_dict(
                loaded_dict[model_name + '_state_dict'])
    return model_dict, data_dict, info, train_cnt, epoch_cnt, rescale, rescale_inv