        'test_losses': [],
        'test_kl_losses': [],
        'test_rec_losses': [],
        'save_times': [],
        'args': [args],
        'last_save': 0,
        'last_plot': 0,
    }

    args.size_training_set = len(train_data_function)
    hsize = data_loader.train_loader.data.shape[1]
    wsize = data_loader.train_loader.data.shape[2]
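
    # Build the three ACN pieces: a ConvVAE encoder that maps frames to
    # latent codes, a PriorNetwork that fits a mixture prior using each
    # code's k nearest neighbors, and a GatedPixelCNN decoder conditioned
    # on the latent code.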

    encoder_model = ConvVAE(
        args.code_length,
        input_size=args.number_condition,
        encoder_output_size=args.encoder_output_size,
    ).to(DEVICE)
    prior_model = PriorNetwork(
        size_training_set=args.size_training_set,
        code_length=args.code_length,
        n_mixtures=args.num_mixtures,
        k=args.num_k,
        require_unique_codes=args.require_unique_codes).to(DEVICE)

    pcnn_decoder = GatedPixelCNN(input_dim=1,
                                 dim=args.possible_values,
                                 n_layers=args.num_pcnn_layers,
                                 n_classes=args.num_classes,
                                 float_condition_size=args.code_length,
                                 last_layer_bias=0.5,
                                 hsize=hsize,
                                 wsize=wsize).to(DEVICE)

    args.size_training_set = len(train_data_function)
    hsize = data_loader.train_loader.data.shape[1]
    wsize = data_loader.train_loader.data.shape[2]

    info = model_dict['info']
    largs = info['args'][-1]

    # older checkpoints were saved before encoder_output_size existed;
    # fall back to the historical default
    if not hasattr(largs, 'encoder_output_size'):
        largs.encoder_output_size = 1000

    encoder_model = ConvVAE(largs.code_length,
                            input_size=largs.number_condition,
                            encoder_output_size=largs.encoder_output_size)
    encoder_model.load_state_dict(model_dict['vae_state_dict'])

    prior_model = PriorNetwork(size_training_set=largs.size_training_set,
                               code_length=largs.code_length,
                               k=largs.num_k)
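    # restore the stored per-example codes so the prior's nearest-neighbor
    # lookup matches its state at checkpoint time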
    prior_model.codes = model_dict['codes']

    pcnn_decoder = GatedPixelCNN(input_dim=1,
                                 dim=largs.possible_values,
                                 n_layers=largs.num_pcnn_layers,
                                 n_classes=largs.num_classes,
                                 float_condition_size=largs.code_length,
                                 last_layer_bias=0.5, hsize=hsize, wsize=wsize)
    info = {
        'train_cnts': [],
        'train_losses': [],
        'test_cnts': [],
        'test_losses': [],
        'save_times': [],
        'args': [args],
        'last_save': 0,
        'last_plot': 0,
    }

    args.size_training_set = len(train_data)
    nchans, hsize, wsize = test_loader.dataset[0][0].shape
    encoder_model = ConvVAE(
        args.code_length,
        input_size=1,
        encoder_output_size=args.encoder_output_size).to(DEVICE)
    prior_model = PriorNetwork(
        size_training_set=args.size_training_set,
        code_length=args.code_length,
        n_mixtures=args.num_mixtures,
        k=args.num_k,
        require_unique_codes=args.require_unique_codes).to(DEVICE)

    pcnn_decoder = GatedPixelCNN(input_dim=1,
                                 dim=args.possible_values,
                                 n_layers=args.num_pcnn_layers,
                                 n_classes=args.num_classes,
                                 float_condition_size=args.code_length,
                                 last_layer_bias=0.5,
                                 hsize=hsize,
                                 wsize=wsize).to(DEVICE)

def init_train():
    """ use args to setup inplace training """
    train_data_path = args.train_buffer
    valid_data_path = args.valid_buffer

    data_dir = os.path.split(train_data_path)[0]

    # we are starting from scratch training this model
    if args.model_loadpath == "":
        run_num = 0
        model_base_filedir = os.path.join(data_dir,
                                          args.savename + '%02d' % run_num)
        while os.path.exists(model_base_filedir):
            run_num += 1
            model_base_filedir = os.path.join(data_dir,
                                              args.savename + '%02d' % run_num)
        os.makedirs(model_base_filedir)
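        # model_base_filedir is now the first unused savenameNN directory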
        model_base_filepath = os.path.join(model_base_filedir, args.savename)
        print("MODEL BASE FILEPATH", model_base_filepath)

        info = {
            'model_train_cnts': [],
            'model_train_losses': {},
            'model_valid_cnts': [],
            'model_valid_losses': {},
            'model_save_times': [],
            'model_last_save': 0,
            'model_last_plot': 0,
            'NORM_BY': 255.0,
            'MODEL_BASE_FILEDIR': model_base_filedir,
            'model_base_filepath': model_base_filepath,
            'model_train_data_file': train_data_path,
            'model_valid_data_file': valid_data_path,
            'NUM_TRAINING_EXAMPLES': args.num_training_examples,
            'NUM_K': args.num_k,
            'NR_LOGISTIC_MIX': args.nr_logistic_mix,
            'NUM_PCNN_FILTERS': args.num_pcnn_filters,
            'NUM_PCNN_LAYERS': args.num_pcnn_layers,
            'ALPHA_REC': args.alpha_rec,
            'ALPHA_ACT': args.alpha_act,
            'ALPHA_REW': args.alpha_rew,
            'MODEL_BATCH_SIZE': args.batch_size,
            'NUMBER_CONDITION': args.num_condition,
            'CODE_LENGTH': args.code_length,
            'NUM_MIXTURES': args.num_mixtures,
            'REQUIRE_UNIQUE_CODES': args.require_unique_codes,
        }
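
        # this dict is saved into every checkpoint; a resumed run (else
        # branch below) reloads it so the original hyperparameters are kept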

        ## size of latents flattened - dependent on architecture
        #info['float_condition_size'] = 100*args.num_z
        ## 3x logistic needed for loss
        ## TODO - change loss
    else:
        print('loading model from: %s' % args.model_loadpath)
        model_dict = torch.load(args.model_loadpath,
                                map_location=lambda storage, loc: storage)
        info = model_dict['model_info']
        model_base_filedir = os.path.split(args.model_loadpath)[0]
        model_base_filepath = os.path.join(model_base_filedir, args.savename)
        info['loaded_from'] = args.model_loadpath
        info['MODEL_BATCH_SIZE'] = args.batch_size
    info['DEVICE'] = DEVICE
    info['MODEL_SAVE_EVERY'] = args.save_every
    info['MODEL_LOG_EVERY_BATCHES'] = args.log_every_batches
    info['model_loadpath'] = args.model_loadpath
    info['MODEL_SAVENAME'] = args.savename
    info['MODEL_LEARNING_RATE'] = args.learning_rate
    # create replay buffer
    train_buffer = make_subset_buffer(
        train_data_path, max_examples=info['NUM_TRAINING_EXAMPLES'])
    valid_buffer = make_subset_buffer(valid_data_path,
                                      max_examples=int(
                                          info['NUM_TRAINING_EXAMPLES'] * .1))
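    # validation gets 10% of the training example budget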
    #valid_buffer = ReplayMemory(load_file=valid_data_path)
    # if train buffer is too large - make random subset
    # 27588 places in 1e6 buffer where reward is nonzero

    info['num_actions'] = train_buffer.num_actions()
    info['size_training_set'] = train_buffer.num_examples()
    info['hsize'] = train_buffer.frame_height
    info['wsize'] = train_buffer.frame_width
    info['num_rewards'] = train_buffer.num_rewards()
    info['HISTORY_SIZE'] = 4

    rewards_weight = 1 - np.array(train_buffer.percentages_rewards())
    actions_weight = 1 - np.array(train_buffer.percentages_actions())
    actions_weight = torch.FloatTensor(actions_weight).to(DEVICE)
    rewards_weight = torch.FloatTensor(rewards_weight).to(DEVICE)
    info['actions_weight'] = actions_weight
    info['rewards_weight'] = rewards_weight
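    # weights are 1 minus each class's frequency, so rare actions/rewards get
    # weight near 1 and common ones near 0, offsetting class imbalance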

    # output mixtures should be 2*nr_logistic_mix + nr_logistic_mix for each
    # decorrelated channel
    #info['num_output_mixtures']= (2*args.nr_logistic_mix+args.nr_logistic_mix)*info['HISTORY_SIZE']
    #nmix = int(info['num_output_mixtures']/info['HISTORY_SIZE'])
    #info['nmix'] = nmix
    encoder_model = ConvVAE(
        info['CODE_LENGTH'],
        input_size=args.num_condition,
        encoder_output_size=args.encoder_output_size,
    ).to(DEVICE)
    prior_model = PriorNetwork(
        size_training_set=info['NUM_TRAINING_EXAMPLES'],
        code_length=info['CODE_LENGTH'],
        n_mixtures=info['NUM_MIXTURES'],
        k=info['NUM_K'],
        require_unique_codes=info['REQUIRE_UNIQUE_CODES'],
    ).to(DEVICE)
    pcnn_decoder = GatedPixelCNN(input_dim=1,
                                 dim=info['NUM_PCNN_FILTERS'],
                                 n_layers=info['NUM_PCNN_LAYERS'],
                                 n_classes=info['num_actions'],
                                 float_condition_size=info['CODE_LENGTH'],
                                 last_layer_bias=0.5,
                                 hsize=info['hsize'],
                                 wsize=info['wsize']).to(DEVICE)
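    # n_classes is info['num_actions'], so the decoder's class condition is
    # presumably the action; the latent code enters via float_condition_size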

    #parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
    parameters = list(encoder_model.parameters()) + list(
        prior_model.parameters())
    opt = optim.Adam(parameters, lr=info['MODEL_LEARNING_RATE'])
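    # NOTE: pcnn_decoder's parameters are excluded from the optimizer (see the
    # commented line above), so the decoder only changes if loaded from a
    # checkpoint below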

    if args.model_loadpath != '':
        print("loading weights from:%s" % args.model_loadpath)
        encoder_model.load_state_dict(model_dict['encoder_model_state_dict'])
        prior_model.load_state_dict(model_dict['prior_model_state_dict'])
        pcnn_decoder.load_state_dict(model_dict['pcnn_decoder_state_dict'])
        #encoder_model.embedding = model_dict['model_embedding']
        opt.load_state_dict(model_dict['opt_state_dict'])

    model_dict = {
        'encoder_model': encoder_model,
        'prior_model': prior_model,
        'pcnn_decoder': pcnn_decoder,
        'opt': opt
    }
    data_buffers = {'train': train_buffer, 'valid': valid_buffer}
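    # model_dict and data_buffers are passed to both the sampling and
    # training entry points below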
    if args.sample:
        sample_acn(info,
                   model_dict,
                   data_buffers,
                   num_samples=args.num_samples,
                   teacher_force=args.teacher_force)
    else:
        train_acn(info, model_dict, data_buffers)
    # backwards-compatible defaults for attributes missing from older
    # checkpoints (kept disabled):
    #if not hasattr(largs, 'possible_values'):
    #    largs.possible_values = 1
    #if not hasattr(largs, 'num_pcnn_layers'):
    #    largs.num_pcnn_layers = 12
    #if not hasattr(largs, 'num_classes'):
    #    largs.num_classes = 10

    largs.size_training_set = len(train_data)

    encoder_model = ConvVAE(largs.code_length,
                            input_size=1,
                            encoder_output_size=largs.encoder_output_size)
    encoder_model.load_state_dict(model_dict['vae_state_dict'])

    prior_model = PriorNetwork(size_training_set=largs.size_training_set,
                               code_length=largs.code_length,
                               k=largs.num_k).to(DEVICE)
    prior_model.codes = model_dict['codes']

    pcnn_decoder = GatedPixelCNN(input_dim=1,
                                 dim=largs.possible_values,
                                 n_layers=largs.num_pcnn_layers,
                                 n_classes=largs.num_classes,
                                 float_condition_size=largs.code_length,
                                 last_layer_bias=0.5,
                                 hsize=hsize,