def valid_acn(train_cnt, do_plot):
    valid_kl_loss = 0.0
    valid_rec_loss = 0.0
    print('starting valid', train_cnt)
    st = time.time()
    valid_cnt = 0
    encoder_model.eval()
    prior_model.eval()
    pcnn_decoder.eval()
    opt.zero_grad()
    states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = valid_data_loader.get_unique_minibatch()
    states = states.to(DEVICE)
    # only the last frame is predicted - 1 channel expected
    next_states = next_states[:, args.number_condition-1:].to(DEVICE)
    actions = actions.to(DEVICE)
    z, u_q = encoder_model(states)
    np_uq = u_q.detach().cpu().numpy()
    if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
        print('bad u_q in valid')
        embed()
    # condition the pixelcnn decoder on the action and the predicted code
    yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states,
                                            class_condition=actions,
                                            float_condition=z))
    mix, u_ps, s_ps = prior_model(u_q)
    kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_states, u_q, mix, u_ps, s_ps)
    valid_kl_loss += kl_loss.item()
    valid_rec_loss += rec_loss.item()
    valid_cnt += states.shape[0]
    if do_plot:
        print('writing img')
        n_imgs = 8
        n = min(states.shape[0], n_imgs)
        bs = states.shape[0]
        h = train_data_loader.data_h
        w = train_data_loader.data_w
        comparison = torch.cat([next_states.view(bs, 1, h, w)[:n],
                                yhat_batch.view(bs, 1, h, w)[:n]])
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison, img_name, nrow=n)
        print('finished writing img', img_name)
    valid_kl_loss /= float(valid_cnt)
    valid_rec_loss /= float(valid_cnt)
    print('====> valid kl loss: {:.4f}'.format(valid_kl_loss))
    print('====> valid rec loss: {:.4f}'.format(valid_rec_loss))
    print('finished valid', time.time()-st)
    return valid_kl_loss, valid_rec_loss
def test_acn(train_cnt, do_plot):
    test_kl_loss = 0.0
    test_rec_loss = 0.0
    print('starting test', train_cnt)
    st = time.time()
    test_cnt = 0
    encoder_model.eval()
    prior_model.eval()
    pcnn_decoder.eval()
    opt.zero_grad()
    data, label, data_index = data_loader.validation_data()
    data = data.to(DEVICE)
    label = label.to(DEVICE)
    z, u_q = encoder_model(data)
    np_uq = u_q.detach().cpu().numpy()
    if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
        print('bad u_q in test')
        embed()
    # condition the pixelcnn decoder on the predicted code
    yhat_batch = torch.sigmoid(pcnn_decoder(x=label, float_condition=z))
    mix, u_ps, s_ps = prior_model(u_q)
    kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, label, u_q, mix, u_ps, s_ps)
    test_kl_loss += kl_loss.item()
    test_rec_loss += rec_loss.item()
    test_cnt += data.shape[0]
    if do_plot:
        print('writing img')
        n = min(data.size(0), 8)
        bs = data.shape[0]
        comparison = torch.cat([label.view(bs, 1, hsize, wsize)[:n],
                                yhat_batch.view(bs, 1, hsize, wsize)[:n]])
        img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
        save_image(comparison.cpu(), img_name, nrow=n)
        print('finished writing img', img_name)
    test_kl_loss /= float(test_cnt)
    test_rec_loss /= float(test_cnt)
    print('====> Test kl loss: {:.4f}'.format(test_kl_loss))
    print('====> Test rec loss: {:.4f}'.format(test_rec_loss))
    print('finished test', time.time()-st)
    return test_kl_loss, test_rec_loss
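# Illustrative sketch only: acn_gmp_loss_function is defined elsewhere in this
# repo. The version below is an assumption about what it computes (a summed
# binary cross entropy reconstruction term plus the negative log likelihood of
# the encoder code under a Gaussian mixture prior), written to make the
# training and eval functions in this file easier to follow. Assumed shapes:
# u_q is (batch, code_size), mix is (batch, n_mixtures) mixture weights, and
# u_ps / s_ps are (batch, n_mixtures, code_size) component means / std devs.
import numpy as np
import torch
import torch.nn.functional as F

def acn_gmp_loss_function_sketch(yhat, y, u_q, mix, u_ps, s_ps):
    # reconstruction term: pixel-wise binary cross entropy, summed over the batch
    rec_loss = F.binary_cross_entropy(yhat, y, reduction='sum')
    # prior term: -log of the mixture density evaluated at the encoder mean,
    # a common single-sample approximation when the encoder variance is fixed
    log_2pi = float(np.log(2 * np.pi))
    # log N(u_q | u_ps, s_ps), summed over code dims, one value per component
    log_probs = -0.5 * (log_2pi + 2 * torch.log(s_ps)
                        + ((u_q.unsqueeze(1) - u_ps) / s_ps) ** 2).sum(-1)
    # log-sum-exp over components with the (log) mixture weights
    kl_loss = -torch.logsumexp(torch.log(mix + 1e-10) + log_probs, dim=1).sum()
    return kl_loss, rec_loss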
def train_acn(train_cnt):
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    batches = 0
    while train_cnt < args.num_examples_to_train:
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        states = states.to(DEVICE)
        # only the last frame is predicted - 1 channel expected
        next_states = next_states[:, args.number_condition-1:].to(DEVICE)
        actions = actions.to(DEVICE)
        z, u_q = encoder_model(states)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('bad u_q in train')
            embed()
        # condition the pixelcnn decoder on the action and the predicted code
        yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states,
                                                class_condition=actions,
                                                float_condition=z))
        # store the latest codes so the prior can do its knn lookup over them
        prior_model.codes[relative_indexes - args.number_condition] = u_q.detach().cpu().numpy()
        mix, u_ps, s_ps = prior_model(u_q)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_states, u_q, mix, u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train_cnt yet
        avg_train_kl_loss = train_kl_loss / float((train_cnt + states.shape[0]) - init_cnt)
        avg_train_rec_loss = train_rec_loss / float((train_cnt + states.shape[0]) - init_cnt)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" % (batches, time.time()-st, train_cnt))
    return train_cnt
def train_acn(train_cnt):
    train_loss = 0
    init_cnt = train_cnt
    st = time.time()
    for batch_idx, (data, label, data_index) in enumerate(train_loader):
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        data = data.to(DEVICE)
        opt.zero_grad()
        z, u_q = encoder_model(data)
        # condition the pixelcnn decoder on the predicted code
        # TODO - this isn't how you sample a pixelcnn
        yhat_batch = torch.sigmoid(pcnn_decoder(x=data, float_condition=z))
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('bad u_q in train')
            embed()
        # store the latest codes so the prior can do its knn lookup over them
        prior_model.codes[data_index] = np_uq
        #prior_model.fit_knn(prior_model.codes)
        # prior output is a gaussian mixture
        mixtures, u_ps, s_ps = prior_model(u_q)
        kl_reg, rec_loss = acn_gmp_loss_function(yhat_batch, data, u_q, mixtures, u_ps, s_ps)
        loss = kl_reg + rec_loss
        if not batch_idx % 10:
            print(train_cnt, batch_idx, kl_reg.item(), rec_loss.item())
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_loss += loss.item()
        opt.step()
        # add batch size because it hasn't been added to train_cnt yet
        avg_train_loss = train_loss / float((train_cnt + data.shape[0]) - init_cnt)
        print('batch', train_cnt, avg_train_loss)
        handle_checkpointing(train_cnt, avg_train_loss)
        train_cnt += len(data)
    print(train_loss)
    print("finished epoch after %s seconds at cnt %s" % (time.time()-st, train_cnt))
    return train_cnt
def test_acn(train_cnt, do_plot):
    encoder_model.eval()
    prior_model.eval()
    pcnn_decoder.eval()
    test_loss = 0
    print('starting test', train_cnt)
    st = time.time()
    print(len(test_loader))
    #with torch.no_grad():
    for i, (data, label, data_index) in enumerate(test_loader):
        data = data.to(DEVICE)
        z, u_q = encoder_model(data)
        np_uq = u_q.detach().cpu().numpy()
        # condition the pixelcnn decoder on the predicted code
        yhat_batch = torch.sigmoid(pcnn_decoder(x=data, float_condition=z))
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('bad u_q in test')
            embed()
        mixtures, u_ps, s_ps = prior_model(u_q)
        kl_reg, rec_loss = acn_gmp_loss_function(yhat_batch, data, u_q, mixtures, u_ps, s_ps)
        loss = kl_reg + rec_loss
        test_loss += loss.item()
        if i == 0 and do_plot:
            print('writing img')
            n = min(data.size(0), 8)
            bs = data.shape[0]
            comparison = torch.cat([data.view(bs, 1, 28, 28)[:n],
                                    yhat_batch.view(bs, 1, 28, 28)[:n]])
            img_name = model_base_filepath + "_%010d_valid_reconstruction.png" % train_cnt
            save_image(comparison.cpu(), img_name, nrow=n)
            print('finished writing img', img_name)
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    print('finished test', time.time()-st)
    return test_loss
def train_acn(train_cnt):
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    batches = 0
    while train_cnt < args.num_examples_to_train:
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        data, label, data_index, is_new_epoch = data_loader.next_unique_batch()
        if is_new_epoch:
            #prior_model.new_epoch()
            print(train_cnt, 'train, is new epoch', prior_model.available_indexes.shape)
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        # inf happens sometime after 0001,680,896
        z, u_q = encoder_model(data)
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('bad u_q in train')
            embed()
        # condition the pixelcnn decoder on the predicted code
        yhat_batch = torch.sigmoid(pcnn_decoder(x=label, float_condition=z))
        # store the latest codes so the prior can do its knn lookup over them
        prior_model.codes[data_index - args.number_condition] = u_q.detach().cpu().numpy()
        mix, u_ps, s_ps = prior_model(u_q)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, label, u_q, mix, u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train_cnt yet
        avg_train_kl_loss = train_kl_loss / float((train_cnt + data.shape[0]) - init_cnt)
        avg_train_rec_loss = train_rec_loss / float((train_cnt + data.shape[0]) - init_cnt)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(data)
        batches += 1
        if not batches % 1000:
            print("finished %s batches after %s seconds at cnt %s" % (batches, time.time()-st, train_cnt))
    return train_cnt
def train_acn(info, model_dict, data_buffers, phase='train'):
    encoder_model = model_dict['encoder_model']
    prior_model = model_dict['prior_model']
    pcnn_decoder = model_dict['pcnn_decoder']
    opt = model_dict['opt']
    # add one to the rewards so that they are all positive
    # use next_states because that is the t-1 action
    if len(info['model_train_cnts']):
        train_cnt = info['model_train_cnts'][-1]
    else:
        train_cnt = 0
    num_batches = data_buffers['train'].count // info['MODEL_BATCH_SIZE']
    while train_cnt < 10000000:
        if phase == 'valid':
            encoder_model.eval()
            prior_model.eval()
            pcnn_decoder.eval()
        else:
            encoder_model.train()
            prior_model.train()
            pcnn_decoder.train()
        batch_num = 0
        data_buffers[phase].reset_unique()
        print('-------------new epoch %s------------------' % phase)
        print('num batches', num_batches)
        while data_buffers[phase].unique_available:
            opt.zero_grad()
            batch = data_buffers[phase].get_unique_minibatch(info['MODEL_BATCH_SIZE'])
            relative_indices = batch[-1]
            states, actions, rewards, next_states = make_state(batch[:-1], info['DEVICE'], info['NORM_BY'])
            next_state = next_states[:, -1:]
            bs = states.shape[0]
            z, u_q = encoder_model(states)
            # decode from the encoder's own decoder rather than the pixelcnn
            #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_state,
            #                                        class_condition=actions,
            #                                        float_condition=z))
            yhat_batch = encoder_model.decode(z)
            # store the latest codes so the prior can do its knn lookup over them
            prior_model.codes[relative_indices] = u_q.detach().cpu().numpy()
            mix, u_ps, s_ps = prior_model(u_q)
            # track losses
            kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_state, u_q, mix, u_ps, s_ps)
            loss = kl_loss + rec_loss
            if not phase == 'valid':
                loss.backward()
                # the pcnn decoder is not part of the optimized parameters in this version
                parameters = list(encoder_model.parameters()) + list(prior_model.parameters())
                clip_grad_value_(parameters, 10)
                opt.step()
                train_cnt += bs
            if not batch_num % info['MODEL_LOG_EVERY_BATCHES']:
                print(phase, train_cnt, batch_num, kl_loss.item(), rec_loss.item())
                info = add_losses(info, train_cnt, phase, kl_loss.item(), rec_loss.item())
            batch_num += 1
        if (train_cnt - info['model_last_save']) >= info['MODEL_SAVE_EVERY']:
            info = add_losses(info, train_cnt, phase, kl_loss.item(), rec_loss.item())
            if phase == 'train':
                # run a valid epoch, then come back here to save and resume training
                phase = 'valid'
            else:
                model_dict = {'encoder_model': encoder_model,
                              'prior_model': prior_model,
                              'pcnn_decoder': pcnn_decoder,
                              'opt': opt}
                info = save_model(info, model_dict)
                phase = 'train'
    model_dict = {'encoder_model': encoder_model,
                  'prior_model': prior_model,
                  'pcnn_decoder': pcnn_decoder,
                  'opt': opt}
    info = save_model(info, model_dict)
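# Hypothetical usage sketch, not part of this repo: the phase-based train_acn
# above expects an info dict of hyperparameters, a model_dict holding the four
# trainable pieces, and per-phase replay buffers. Something along these lines
# could drive it. The dict keys match those referenced inside train_acn; the
# model, optimizer, and buffer objects (encoder_model, prior_model,
# pcnn_decoder, opt, train_buffer, valid_buffer) are placeholders assumed to
# have been constructed earlier in the script.
import torch

if __name__ == '__main__':
    info = {'model_train_cnts': [],
            'model_last_save': 0,
            'MODEL_BATCH_SIZE': 128,
            'MODEL_SAVE_EVERY': 50000,
            'MODEL_LOG_EVERY_BATCHES': 100,
            'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
            'NORM_BY': 255.0}
    model_dict = {'encoder_model': encoder_model,
                  'prior_model': prior_model,
                  'pcnn_decoder': pcnn_decoder,
                  'opt': opt}
    data_buffers = {'train': train_buffer, 'valid': valid_buffer}
    train_acn(info, model_dict, data_buffers, phase='train')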