def create_models(info,
                  model_loadpath='',
                  dataset_name='FashionMNIST',
                  load_data=True):
    '''
    Build the mid ACN-VQVAE model, the prior model and a shared Adam
    optimizer, restoring them from a checkpoint when one is given.

    info: settings dict; it is updated in place with frame-shrink settings,
        'nmix' and 'output_dim', and (when resuming) with the entries of the
        checkpointed info dict.
    model_loadpath: checkpoint path passed to torch.load; '' means start
        fresh.
    dataset_name: not used by this function.
    load_data: when true, data_dict and info come from load_data_fn(info);
        otherwise data_dict is an empty dict.

    Returns (model_dict, data_dict, info, train_cnt, epoch_cnt, rescale,
    rescale_inv). rescale and rescale_inv are not defined in this function;
    they are looked up in the enclosing module scope.
    '''
    # settings that always come from the current run, never the checkpoint
    keep_current = (
        'device', 'batch_size', 'save_every_epochs', 'base_filepath',
        'model_loadpath', 'perplexity', 'use_pred',
    )
    train_cnt, epoch_cnt = 0, 0
    current_args = info['args']
    resuming = model_loadpath != ''

    if resuming:
        # tensors are kept on the storage they are deserialized to (cpu)
        checkpoint = torch.load(model_loadpath,
                                map_location=lambda storage, loc: storage)
        for key, value in checkpoint['info'].items():
            # a key is kept only if it is preserved AND already set locally
            if key in keep_current and key in info:
                continue
            info[key] = value
        train_cnt = info['train_cnts'][-1]
        epoch_cnt = info['epoch_cnt']
        # extend the checkpointed args history with this run's args
        info['args'].append(current_args)

    # (kernel_size, trim_after, trim_before, frame_height, frame_width)
    if info['small']:
        frame_cfg = ((4, 4), 0, 2, 20, 20)
    elif info['big']:
        # dont adjust replay buffer size
        frame_cfg = ((0, 0), 0, 0, 84, 84)
    else:
        frame_cfg = ((2, 2), 1, 0, 40, 40)
    kernel_size, trim_after, trim_before, height, width = frame_cfg
    info['frame_shrink_kernel_size'] = kernel_size
    info['frame_shrink_trim_after'] = trim_after
    info['frame_shrink_trim_before'] = trim_before
    info['frame_height'] = height
    info['frame_width'] = width

    # for dml prediction the network outputs a mixture of size nmix
    n_mix = info['nr_logistic_mix']
    info['nmix'] = (2 * n_mix + info['nr_logistic_mix']) * info['target_channels']
    info['output_dim'] = info['nmix']

    data_dict = {}
    if load_data:
        data_dict, info = load_data_fn(info)

    # acn prior with vqvae embedding
    acn = midACNVQVAEres(code_len=info['code_length'],
                         input_size=info['input_channels'],
                         output_size=info['output_dim'],
                         hidden_size=info['hidden_size'],
                         num_clusters=info['num_vqk'],
                         num_z=info['num_z'],
                         num_actions=info['num_actions'],
                         num_rewards=info['num_rewards'],
                         small=info['small'],
                         big=info['big'])
    acn = acn.to(info['device'])

    prior = tPTPriorNetwork(size_training_set=info['size_training_set'],
                            code_length=info['code_length'],
                            k=info['num_k'])
    prior = prior.to(info['device'])
    prior.codes = prior.codes.to(info['device'])

    model_dict = {'mid_vq_acn_model': acn, 'prior_model': prior}

    # one optimizer over the parameters of every model
    trainable = []
    for name, model in model_dict.items():
        trainable.extend(model.parameters())
        print('created %s model with %s parameters' %
              (name, count_parameters(model)))
    model_dict['opt'] = optim.Adam(trainable, lr=info['learning_rate'])

    if resuming:
        # every entry, including 'opt', is restored from '<name>_state_dict'
        for name, component in model_dict.items():
            component.load_state_dict(checkpoint[name + '_state_dict'])
    return model_dict, data_dict, info, train_cnt, epoch_cnt, rescale, rescale_inv
def create_models(info, model_loadpath='', dataset_name='FashionMNIST'):
    '''
    Build the forward ACN-VQVAE model, the prior model, the gated PixelCNN
    decoder and a shared Adam optimizer, restoring them from a checkpoint
    when one is given.

    info: settings dict; it is updated in place with frame size, action and
        reward counts, 'nmix', 'output_dim' and 'last_layer_bias', and (when
        resuming) with the entries of the checkpointed info dict.
    model_loadpath: checkpoint path passed to torch.load; '' means start
        fresh.
    dataset_name: not used by this function.

    Returns (model_dict, data_dict, info, train_cnt, epoch_cnt, rescale,
    rescale_inv). rescale and rescale_inv are not defined in this function;
    they are looked up in the enclosing module scope.
    '''
    train_cnt = 0
    epoch_cnt = 0

    # use argparse device no matter what info dict is loaded
    preserve_args = ['device', 'batch_size', 'save_every_epochs',
                     'base_filepath', 'model_loadpath', 'perplexity',
                     'use_pred']
    largs = info['args']
    # A single flag, driven by the parameter, decides both the info merge
    # and the later state-dict restore so the two can never disagree.
    loaded = model_loadpath != ''
    if loaded:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        _dict = torch.load(model_loadpath,
                           map_location=lambda storage, loc: storage)
        dinfo = _dict['info']
        for key in dinfo:
            # take the checkpointed value unless the key is preserved and
            # already set for the current run
            if key not in preserve_args or key not in info:
                info[key] = dinfo[key]
        train_cnt = info['train_cnts'][-1]
        epoch_cnt = info['epoch_cnt']
        # extend the checkpointed args history with this run's args
        info['args'].append(largs)

    # the second return value (data paths) is not needed here
    data_dict, _data_paths = make_random_subset_buffers(
        dataset_path=info['base_datadir'],
        buffer_path=info['base_train_buffer_path'],
        train_max_examples=info['size_training_set'],
        kernel_size=info['frame_shrink_kernel_size'],
        trim=info['frame_shrink_trim'])

    train_data = data_dict['train']
    info['frame_height'] = train_data.frame_height
    info['frame_width'] = train_data.frame_width
    info['num_actions'] = train_data.num_actions()
    info['num_rewards'] = train_data.num_rewards()
    # assume actions/rewards are 0 indexed
    assert min(train_data.rewards) == 0
    assert max(train_data.rewards) == info['num_rewards'] - 1
    assert min(train_data.actions) == 0
    assert max(train_data.actions) == info['num_actions'] - 1

    # pixel cnn architecture is dependent on loss
    # for dml prediction, need to output mixture of size nmix
    info['nmix'] = (2 * info['nr_logistic_mix'] +
                    info['nr_logistic_mix']) * info['target_channels']
    info['output_dim'] = info['nmix']
    # last layer for pcnn - bias is 0 for dml
    info['last_layer_bias'] = 0.0

    # setup models
    # acn prior with vqvae embedding
    fwd_vq_acn_model = fwdACNVQVAEres(code_len=info['code_length'],
                                      input_size=info['input_channels'],
                                      output_size=info['output_dim'],
                                      hidden_size=info['hidden_size'],
                                      num_clusters=info['num_vqk'],
                                      num_z=info['num_z'],
                                      num_actions=info['num_actions'],
                                      num_rewards=info['num_rewards'],
                                      ).to(info['device'])

    prior_model = tPTPriorNetwork(size_training_set=info['size_training_set'],
                                  code_length=info['code_length'],
                                  k=info['num_k']).to(info['device'])
    prior_model.codes = prior_model.codes.to(info['device'])

    pcnn_decoder = GatedPixelCNN(
        input_dim=info['target_channels'],
        output_dim=info['output_dim'],
        dim=info['pixel_cnn_dim'],
        n_layers=info['num_pcnn_layers'],
        # output dim is same as deconv output in this case
        spatial_condition_size=info['output_dim'],
        last_layer_bias=info['last_layer_bias'],
        use_batch_norm=False,
        output_projection_size=info['output_projection_size'],
    ).to(info['device'])

    model_dict = {'fwd_vq_acn_model': fwd_vq_acn_model,
                  'prior_model': prior_model,
                  'pcnn_decoder_model': pcnn_decoder}
    # one optimizer over the parameters of every model
    parameters = []
    for name, model in model_dict.items():
        parameters += list(model.parameters())
        print('created %s model with %s parameters' % (name, count_parameters(model)))

    model_dict['opt'] = optim.Adam(parameters, lr=info['learning_rate'])

    # Previously this checked the global args.model_loadpath rather than the
    # model_loadpath parameter that decided whether _dict was loaded above.
    if loaded:
        # every entry, including 'opt', is restored from '<name>_state_dict'
        for name, model in model_dict.items():
            model.load_state_dict(_dict[name + '_state_dict'])
    return model_dict, data_dict, info, train_cnt, epoch_cnt, rescale, rescale_inv
# Example no. 3
# 0
def create_models(info, model_loadpath=''):
    '''
    Build the ACN model (VQ-VAE variant when info['vq_decoder'] is true),
    the prior model and a shared Adam optimizer for an MNIST-style dataset,
    restoring model weights from a checkpoint when one is given.

    info: settings dict; it is updated in place with dataset channel and
        size entries, 'nmix' and 'output_dim', and (when resuming) with the
        entries of the checkpointed info dict. info['size_training_set'] is
        only set from the dataset when no checkpoint was loaded.
    model_loadpath: checkpoint path passed to torch.load; '' means start
        fresh. If '<model_loadpath>.cd' exists, that file is loaded instead.

    Returns (model_dict, data_dict, info, train_cnt, epoch_cnt, rescale,
    rescale_inv), where rescale / rescale_inv map data to and from the range
    used by the reconstruction loss.

    Raises ValueError if info['rec_loss_type'] is neither 'dml' nor 'bce'.
    '''
    train_cnt = 0
    epoch_cnt = 0
    # set once a checkpointed info dict has been merged into info
    loaded = False

    # use argparse device no matter what info dict is loaded
    preserve_args = [
        'device', 'batch_size', 'save_every_epochs', 'base_filepath',
        'model_loadpath', 'perplexity', 'num_examples_to_train'
    ]
    largs = info['args']
    # load model if given a path
    if model_loadpath != '':
        if os.path.exists(model_loadpath + '.cd'):
            model_loadpath = model_loadpath + '.cd'
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        _dict = torch.load(model_loadpath,
                           map_location=lambda storage, loc: storage)
        dinfo = _dict['info']
        for key in dinfo:
            # take the checkpointed value unless the key is preserved and
            # already set for the current run
            if key not in preserve_args or key not in info:
                info[key] = dinfo[key]
        train_cnt = info['train_cnts'][-1]
        epoch_cnt = info['epoch_cnt']
        # extend the checkpointed args history with this run's args
        info['args'].append(largs)
        loaded = True

    rec_loss_type = info['rec_loss_type']
    if rec_loss_type == 'dml':
        # data going into dml should be bt -1 and 1
        def rescale(x):
            return (x - 0.5) * 2.

        def rescale_inv(x):
            return (0.5 * x) + 0.5
    elif rec_loss_type == 'bce':
        def rescale(x):
            return x

        def rescale_inv(x):
            return x
    else:
        # fail early instead of hitting an unbound rescale further down
        raise ValueError(
            "unsupported rec_loss_type %r; expected 'dml' or 'bce'" %
            (rec_loss_type,))

    dataset_transforms = transforms.Compose([transforms.ToTensor(), rescale])
    data_output = create_mnist_datasets(dataset_name=info['dataset_name'],
                                        base_datadir=info['base_datadir'],
                                        batch_size=info['batch_size'],
                                        dataset_transforms=dataset_transforms)
    data_dict, size_training_set, num_input_chans, num_output_chans, hsize, wsize = data_output
    info['num_input_chans'] = num_input_chans
    # NOTE(review): num_output_chans is unpacked above but the input channel
    # count is stored here - confirm this is intended.
    info['num_output_chans'] = num_input_chans
    info['hsize'] = hsize
    info['wsize'] = wsize

    # keep the checkpointed training-set size when resuming
    if not loaded:
        info['size_training_set'] = size_training_set

    # pixel cnn architecture is dependent on loss
    # for dml prediction, need to output mixture of size nmix
    info['nmix'] = (2 * info['nr_logistic_mix'] +
                    info['nr_logistic_mix']) * info['target_channels']
    info['output_dim'] = info['nmix']

    # setup models
    # acn prior with vqvae embedding
    if info['vq_decoder']:
        acn_model = ACNVQVAEresMNIST(
            code_len=info['code_length'],
            input_size=info['input_channels'],
            output_size=info['output_dim'],
            hidden_size=info['hidden_size'],
            num_clusters=info['num_vqk'],
            num_z=info['num_z'],
        ).to(info['device'])
    else:
        acn_model = ACNresMNIST(
            code_len=info['code_length'],
            input_size=info['input_channels'],
            output_size=info['output_dim'],
            hidden_size=info['hidden_size'],
        ).to(info['device'])

    prior_model = tPTPriorNetwork(size_training_set=info['size_training_set'],
                                  code_length=info['code_length'],
                                  k=info['num_k']).to(info['device'])

    model_dict = {'acn_model': acn_model, 'prior_model': prior_model}
    # one optimizer over the parameters of every model
    parameters = []
    for name, model in model_dict.items():
        parameters += list(model.parameters())
        print('created %s model with %s parameters' %
              (name, count_parameters(model)))

    model_dict['opt'] = optim.Adam(parameters, lr=info['learning_rate'])

    if loaded:
        # only entries named '*_model*' are restored, so the optimizer
        # state is not loaded here
        for name, model in model_dict.items():
            if '_model' in name:
                model.load_state_dict(_dict[name + '_state_dict'])
        saved_codes = _dict['codes']
        same = (saved_codes == model_dict['prior_model'].codes.cpu()
                ).sum().item()
        # make sure that loaded codes are the same
        assert same == saved_codes.shape[0] * saved_codes.shape[1]
    return model_dict, data_dict, info, train_cnt, epoch_cnt, rescale, rescale_inv