def valid_acn(train_cnt, do_plot):
    """Run one validation minibatch through the ACN pipeline.

    Uses module-level globals: encoder_model, prior_model, pcnn_decoder,
    opt, args, DEVICE, valid_data_loader, train_data_loader,
    model_base_filepath, acn_gmp_loss_function, save_image, embed.

    Parameters
    ----------
    train_cnt : int
        Number of training examples seen so far; used for logging and
        for the reconstruction image filename.
    do_plot : bool
        When True, write a reconstruction comparison image to disk.

    Returns
    -------
    (float, float)
        Per-example KL loss and reconstruction loss for the minibatch.
    """
    print('starting valid', train_cnt)
    st = time.time()
    # put all three model components into evaluation mode
    encoder_model.eval()
    prior_model.eval()
    pcnn_decoder.eval()
    opt.zero_grad()
    states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_unique_minibatch()
    states = states.to(DEVICE)
    # keep only the last (target) frame -- 1 channel expected
    next_states = next_states[:,args.number_condition-1:].to(DEVICE)
    actions = actions.to(DEVICE)
    z, u_q = encoder_model(states)

    # sanity check: drop into a debugger if the encoder produced inf/nan
    np_uq = u_q.detach().cpu().numpy()
    if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
        print('baad')
        embed()

    # condition the pixelcnn decoder on the action and the latent code
    yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, class_condition=actions, float_condition=z))
    mix, u_ps, s_ps = prior_model(u_q)
    kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_states, u_q, mix, u_ps, s_ps)
    # only one minibatch is evaluated, so the "sum" is a single term
    valid_kl_loss = kl_loss.item()
    valid_rec_loss = rec_loss.item()
    valid_cnt = states.shape[0]
    if do_plot:
        print('writing img')
        n_imgs = 8
        n = min(states.shape[0], n_imgs)
        # BUGFIX: use the actual minibatch size for view() -- the loader can
        # return fewer than args.batch_size examples, and view(args.batch_size,
        # ...) would then raise a RuntimeError
        bs = states.shape[0]
        h = train_data_loader.data_h
        w = train_data_loader.data_w
        comparison = torch.cat([next_states.view(bs,1,h,w)[:n],
                                yhat_batch.view(bs,1,h,w)[:n]])
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png"%train_cnt
        save_image(comparison, img_name, nrow=n)
        print('finished writing img', img_name)
    # report per-example losses
    valid_kl_loss/=float(valid_cnt)
    valid_rec_loss/=float(valid_cnt)
    print('====> valid kl loss: {:.4f}'.format(valid_kl_loss))
    print('====> valid rec loss: {:.4f}'.format(valid_rec_loss))
    print('finished valid', time.time()-st)
    return valid_kl_loss, valid_rec_loss
def test_acn(train_cnt, do_plot):
    """Evaluate the ACN pipeline on one validation batch from data_loader.

    NOTE(review): this module defines ``test_acn`` twice; the later
    definition shadows this one at import time -- confirm which variant
    is intended to be live.

    Uses module-level globals: encoder_model, prior_model, pcnn_decoder,
    opt, DEVICE, data_loader, hsize, wsize, model_base_filepath,
    acn_gmp_loss_function, save_image, embed.

    Parameters
    ----------
    train_cnt : int
        Number of training examples seen so far; used for logging and
        for the reconstruction image filename.
    do_plot : bool
        When True, write a reconstruction comparison image to disk.

    Returns
    -------
    (float, float)
        Per-example KL loss and reconstruction loss for the batch.
    """
    print('starting test', train_cnt)
    st = time.time()
    # evaluation mode for all three model components
    encoder_model.eval()
    prior_model.eval()
    pcnn_decoder.eval()
    opt.zero_grad()
    data, label, data_index = data_loader.validation_data()
    data = data.to(DEVICE)
    label = label.to(DEVICE)
    z, u_q = encoder_model(data)

    # sanity check: drop into a debugger if the encoder produced inf/nan
    np_uq = u_q.detach().cpu().numpy()
    if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
        print('baad')
        embed()

    # condition the pixelcnn decoder on the latent code; the target frame
    # is 'label' here (next-frame prediction)
    yhat_batch = torch.sigmoid(pcnn_decoder(x=label, float_condition=z))
    mix, u_ps, s_ps = prior_model(u_q)
    kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, label, u_q, mix,
                                              u_ps, s_ps)
    # only one batch is evaluated, so the "sum" is a single term
    test_kl_loss = kl_loss.item()
    test_rec_loss = rec_loss.item()
    test_cnt = data.shape[0]
    if do_plot:
        print('writing img')
        n = min(data.size(0), 8)
        bs = data.shape[0]
        comparison = torch.cat([
            label.view(bs, 1, hsize, wsize)[:n],
            yhat_batch.view(bs, 1, hsize, wsize)[:n]
        ])
        # filename says "valid" although this is the test routine --
        # kept as-is to preserve output naming
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison.cpu(), img_name, nrow=n)
        print('finished writing img', img_name)

    # report per-example losses
    test_kl_loss /= float(test_cnt)
    test_rec_loss /= float(test_cnt)
    print('====> Test kl loss: {:.4f}'.format(test_kl_loss))
    print('====> Test rec loss: {:.4f}'.format(test_rec_loss))
    print('finished test', time.time() - st)
    return test_kl_loss, test_rec_loss
def train_acn(train_cnt):
    """Train the ACN (encoder + GMM prior + pixelcnn decoder) on replay
    minibatches until ``args.num_examples_to_train`` examples are seen.

    NOTE(review): this module defines ``train_acn`` several times; a later
    definition shadows this one at import time -- confirm which variant is
    intended to be live.

    Uses module-level globals: encoder_model, prior_model, pcnn_decoder,
    opt, args, DEVICE, train_data_loader, acn_gmp_loss_function,
    clip_grad_value_, handle_checkpointing, embed.

    Parameters
    ----------
    train_cnt : int
        Running count of training examples already consumed.

    Returns
    -------
    int
        Updated example count after the loop exits.
    """
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        # training mode for all three model components
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        states = states.to(DEVICE)
        # 1 channel expected
        next_states = next_states[:,args.number_condition-1:].to(DEVICE)
        actions = actions.to(DEVICE)
        z, u_q = encoder_model(states)
        # sanity check: drop into a debugger if the encoder produced inf/nan
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()

        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, class_condition=actions, float_condition=z))
        #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, float_condition=z))
        #print(train_cnt)
        # store this batch's posterior means in the prior's code table,
        # offset by the number of conditioning frames; must happen before
        # prior_model(u_q) is called below
        prior_model.codes[relative_indexes-args.number_condition] = u_q.detach().cpu().numpy()
        # second inf/nan check (duplicated defensively after the code write)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        mix, u_ps, s_ps = prior_model(u_q)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_states, u_q, mix, u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        # clip gradients of all three components before the optimizer step
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss+= kl_loss.item()
        train_rec_loss+= rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_kl_loss = train_kl_loss/float((train_cnt+states.shape[0])-init_cnt)
        avg_train_rec_loss = train_rec_loss/float((train_cnt+states.shape[0])-init_cnt)
        # checkpoint BEFORE incrementing train_cnt (intentional ordering)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt+=len(states)

        batches+=1
        # NOTE: message says "epoch" but this fires every 1000 minibatches
        if not batches%1000:
            print("finished %s epoch after %s seconds at cnt %s"%(batches, time.time()-st, train_cnt))
    return train_cnt
def train_acn(train_cnt):
    """Train the ACN for one epoch over ``train_loader``.

    NOTE(review): this is one of several ``train_acn`` definitions in this
    module; the last one defined shadows the others at import time.

    Uses module-level globals: encoder_model, prior_model, pcnn_decoder,
    opt, DEVICE, train_loader, acn_gmp_loss_function, clip_grad_value_,
    handle_checkpointing, embed.

    Parameters
    ----------
    train_cnt : int
        Running count of training examples already consumed.

    Returns
    -------
    int
        Updated example count after one full pass over train_loader.
    """
    train_loss = 0
    init_cnt = train_cnt
    st = time.time()
    for batch_idx, (data, label, data_index) in enumerate(train_loader):
        # training mode for all three model components
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        lst = time.time()
        data = data.to(DEVICE)
        opt.zero_grad()
        z, u_q = encoder_model(data)
        #yhat_batch = encoder_model.decode(u_q, s_q, data)
        # add the predicted codes to the input
        # TODO - this isn't how you sample pcnn
        yhat_batch = torch.sigmoid(pcnn_decoder(x=data, float_condition=z))

        # sanity check: drop into a debugger if the encoder produced inf/nan
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        # store this batch's posterior means in the prior's code table;
        # must happen before prior_model(u_q) is called below
        prior_model.codes[data_index] = np_uq
        #prior_model.fit_knn(prior_model.codes)
        # output is gmp
        mixtures, u_ps, s_ps = prior_model(u_q)
        kl_reg, rec_loss = acn_gmp_loss_function(yhat_batch, data, u_q,
                                                 mixtures, u_ps, s_ps)
        loss = kl_reg + rec_loss
        # periodic progress log every 10 batches
        if not batch_idx % 10:
            print(train_cnt, batch_idx, kl_reg.item(), rec_loss.item())
        loss.backward()
        # clip gradients of all three components before the optimizer step
        parameters = list(encoder_model.parameters()) + list(
            prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_loss += loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_loss = train_loss / float((train_cnt + data.shape[0]) -
                                            init_cnt)
        print('batch', train_cnt, avg_train_loss)
        # checkpoint BEFORE incrementing train_cnt (intentional ordering)
        handle_checkpointing(train_cnt, avg_train_loss)
        train_cnt += len(data)
    print(train_loss)
    print("finished epoch after %s seconds at cnt %s" %
          (time.time() - st, train_cnt))
    return train_cnt
def test_acn(train_cnt, do_plot):
    """Evaluate the ACN over the full ``test_loader`` and report the mean
    per-example loss.

    NOTE(review): this definition shadows the earlier ``test_acn`` in this
    module (whichever is defined last wins at import time).

    Uses module-level globals: encoder_model, prior_model, pcnn_decoder,
    DEVICE, test_loader, model_base_filepath, acn_gmp_loss_function,
    save_image, embed.

    Parameters
    ----------
    train_cnt : int
        Number of training examples seen so far; used for logging and
        the reconstruction image filename.
    do_plot : bool
        When True, save a reconstruction comparison image for the first
        batch only.

    Returns
    -------
    float
        Total loss divided by len(test_loader.dataset).
    """
    encoder_model.eval()
    prior_model.eval()
    pcnn_decoder.eval()
    test_loss = 0
    print('starting test', train_cnt)
    st = time.time()
    print(len(test_loader))
    # NOTE(review): no torch.no_grad() here -- gradients are tracked during
    # evaluation, which wastes memory; confirm whether this is intentional
    #with torch.no_grad():
    for i, (data, label, data_index) in enumerate(test_loader):
        lst = time.time()
        data = data.to(DEVICE)
        z, u_q = encoder_model(data)
        np_uq = u_q.detach().cpu().numpy()
        #yhat_batch = encoder_model.decode(u_q, s_q, data)
        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=data, float_condition=z))
        # sanity check: drop into a debugger if the encoder produced inf/nan
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('baad')
            embed()
        mixtures, u_ps, s_ps = prior_model(u_q)
        kl_reg, rec_loss = acn_gmp_loss_function(yhat_batch, data, u_q,
                                                 mixtures, u_ps, s_ps)
        loss = kl_reg + rec_loss
        #loss = acn_loss_function(yhat_batch, data, u_q, u_p, s_p)
        test_loss += loss.item()
        # plot only the first batch; 28x28 hard-coded (MNIST-sized images --
        # TODO confirm against the data source)
        if i == 0 and do_plot:
            print('writing img')
            n = min(data.size(0), 8)
            bs = data.shape[0]
            comparison = torch.cat([
                data.view(bs, 1, 28, 28)[:n],
                yhat_batch.view(bs, 1, 28, 28)[:n]
            ])
            img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
            save_image(comparison.cpu(), img_name, nrow=n)
            print('finished writing img', img_name)
        #print('loop test', i, time.time()-lst)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    print('finished test', time.time() - st)
    return test_loss
def train_acn(train_cnt):
    """Train the ACN on unique batches from ``data_loader`` until
    ``args.num_examples_to_train`` examples are seen.

    NOTE(review): one of several ``train_acn`` definitions in this module;
    the last one defined shadows the others at import time.

    Uses module-level globals: encoder_model, prior_model, pcnn_decoder,
    opt, args, DEVICE, data_loader, acn_gmp_loss_function,
    clip_grad_value_, handle_checkpointing, embed.

    Parameters
    ----------
    train_cnt : int
        Running count of training examples already consumed.

    Returns
    -------
    int
        Updated example count after the loop exits.
    """
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        # training mode for all three model components
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        lst = time.time()
        data, label, data_index, is_new_epoch = data_loader.next_unique_batch()
        if is_new_epoch:
            #    prior_model.new_epoch()
            print(train_cnt, 'train, is new epoch',
                  prior_model.available_indexes.shape)
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        #  inf happens sometime after 0001,680,896
        z, u_q = encoder_model(data)
        # sanity check: drop into a debugger if the encoder produced inf/nan
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()

        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=label, float_condition=z))
        #print(train_cnt)
        # store this batch's posterior means in the prior's code table,
        # offset by the number of conditioning frames; must happen before
        # prior_model(u_q) is called below
        prior_model.codes[data_index -
                          args.number_condition] = u_q.detach().cpu().numpy()
        #mixtures, u_ps, s_ps = prior_model(u_q)
        #loss = acn_gmp_loss_function(yhat_batch, label, u_q, mixtures, u_ps, s_ps)
        # second inf/nan check (duplicated defensively after the code write)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        #loss.backward()


#        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
#        clip_grad_value_(parameters, 10)
#        train_loss+= loss.item()
        mix, u_ps, s_ps = prior_model(u_q)
        #kl_loss, rec_loss = acn_loss_function(yhat_batch, data, u_q, u_ps, s_ps)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, label, u_q, mix,
                                                  u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        # clip gradients of all three components before the optimizer step
        parameters = list(encoder_model.parameters()) + list(
            prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_kl_loss = train_kl_loss / float((train_cnt + data.shape[0]) -
                                                  init_cnt)
        avg_train_rec_loss = train_rec_loss / float(
            (train_cnt + data.shape[0]) - init_cnt)
        # checkpoint BEFORE incrementing train_cnt (intentional ordering)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(data)

        # add batch size because it hasn't been added to train cnt yet
        #        avg_train_loss = train_loss/float((train_cnt+data.shape[0])-init_cnt)
        batches += 1
        # NOTE: message says "epoch" but this fires every 1000 minibatches
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt
def train_acn(info, model_dict, data_buffers, phase='train'):
    """Train/validate the ACN, alternating phases, until 10M examples.

    NOTE(review): this (last) definition shadows the earlier ``train_acn``
    variants at import time. Unlike the earlier variants, this one uses
    ``encoder_model.decode(z)`` for reconstruction and does not call
    ``pcnn_decoder`` in the forward pass.

    Parameters
    ----------
    info : dict
        Run configuration/state; reads 'model_train_cnts',
        'MODEL_BATCH_SIZE', 'DEVICE', 'NORM_BY',
        'MODEL_LOG_EVERY_BATCHES', 'model_last_save',
        'MODEL_SAVE_EVERY'; updated via add_losses / save_model.
    model_dict : dict
        Holds 'encoder_model', 'prior_model', 'pcnn_decoder', 'opt'.
    data_buffers : dict
        Replay buffers keyed by phase ('train'/'valid') supporting
        count, reset_unique, unique_available, get_unique_minibatch.
    phase : str
        Starting phase, 'train' or 'valid'; toggles internally when a
        save threshold is reached.

    Returns
    -------
    None
        (The final ``info = save_model(...)`` result is not returned --
        NOTE(review): confirm whether the caller expects it.)
    """
    encoder_model = model_dict['encoder_model']
    prior_model = model_dict['prior_model']
    pcnn_decoder = model_dict['pcnn_decoder']
    opt = model_dict['opt']

    # add one to the rewards so that they are all positive
    # use next_states because that is the t-1 action

    # resume from the last recorded training count, or start fresh
    if len(info['model_train_cnts']):
        train_cnt = info['model_train_cnts'][-1]
    else:
        train_cnt = 0

    # NOTE: num_batches is computed once from the 'train' buffer and only
    # used for logging
    num_batches = data_buffers['train'].count // info['MODEL_BATCH_SIZE']
    while train_cnt < 10000000:
        # eval mode in the valid phase, train mode otherwise
        if phase == 'valid':
            encoder_model.eval()
            prior_model.eval()
            pcnn_decoder.eval()
        else:
            encoder_model.train()
            prior_model.train()
            pcnn_decoder.train()

        batch_num = 0
        data_buffers[phase].reset_unique()
        print('-------------new epoch %s------------------' % phase)
        print('num batches', num_batches)
        while data_buffers[phase].unique_available:
            opt.zero_grad()
            batch = data_buffers[phase].get_unique_minibatch(
                info['MODEL_BATCH_SIZE'])
            relative_indices = batch[-1]
            states, actions, rewards, next_states = make_state(
                batch[:-1], info['DEVICE'], info['NORM_BY'])
            # target is the last frame of next_states
            next_state = next_states[:, -1:]
            bs = states.shape[0]
            #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
            z, u_q = encoder_model(states)

            # add the predicted codes to the input
            #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_state,
            #                                        class_condition=actions,
            #                                        float_condition=z))
            yhat_batch = encoder_model.decode(z)
            # store this batch's posterior means in the prior's code table;
            # NOTE(review): also done during the valid phase -- confirm that
            # overwriting codes with validation batches is intended
            prior_model.codes[relative_indices] = u_q.detach().cpu().numpy()

            mix, u_ps, s_ps = prior_model(u_q)

            # track losses
            kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_state,
                                                      u_q, mix, u_ps, s_ps)
            loss = kl_loss + rec_loss
            # batch size hasn't been added to train_cnt yet

            # only step/clip in the train phase
            if not phase == 'valid':
                loss.backward()
                #parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
                # NOTE: pcnn_decoder is deliberately excluded from clipping
                # here (it is unused in this variant's forward pass)
                parameters = list(encoder_model.parameters()) + list(
                    prior_model.parameters())
                clip_grad_value_(parameters, 10)
                opt.step()
                train_cnt += bs

            if not batch_num % info['MODEL_LOG_EVERY_BATCHES']:
                print(phase, train_cnt, batch_num, kl_loss.item(),
                      rec_loss.item())
                info = add_losses(info, train_cnt, phase, kl_loss.item(),
                                  rec_loss.item())
            batch_num += 1

        # after an epoch: if enough new examples since the last save,
        # toggle phases (train -> valid -> save -> train)
        # NOTE(review): kl_loss/rec_loss are unbound here if the inner loop
        # ran zero batches -- would raise NameError; confirm buffers are
        # never empty
        if (((train_cnt - info['model_last_save']) >=
             info['MODEL_SAVE_EVERY'])):
            info = add_losses(info, train_cnt, phase, kl_loss.item(),
                              rec_loss.item())
            if phase == 'train':
                # run as valid phase and get back to here
                phase = 'valid'
            else:
                model_dict = {
                    'encoder_model': encoder_model,
                    'prior_model': prior_model,
                    'pcnn_decoder': pcnn_decoder,
                    'opt': opt
                }
                info = save_model(info, model_dict)
                phase = 'train'

    # final save after the training budget is exhausted
    model_dict = {
        'encoder_model': encoder_model,
        'prior_model': prior_model,
        'pcnn_decoder': pcnn_decoder,
        'opt': opt
    }

    info = save_model(info, model_dict)