def train(train_data_loader, network, optimizer, writer, jter_count):
    """Run one training epoch over `train_data_loader` and return the updated step count.

    For every batch: builds ground-truth tensors from the concatenated targets,
    runs the network, computes the mixture-density + pen losses, takes one
    optimizer step, and logs scalar summaries. Every `gv.update_step` batches it
    additionally samples a predicted stroke sequence and logs image summaries.

    Args:
        train_data_loader: loader exposing ``get_data()`` yielding
            ``(data, data_gt)`` batches (project type).
        network: model exposing ``forward_unlooped`` and a ``.go`` head with
            ``get_mixture_coef`` / ``loss_distr`` / ``sample_action`` / ``val_loss``.
        optimizer: torch optimizer over ``network``'s parameters.
        writer: summary writer forwarded to ``utils.write_summaries``.
        jter_count: global step offset carried across epochs.

    Returns:
        ``jter_count + last_batch_index`` — the new global step offset.
    """
    data_fetcher = train_data_loader.get_data()
    for jter, (data, data_gt) in enumerate(data_fetcher):
        # Targets: column 0 = pen-down flag, columns 1/2 = x/y coordinates
        # (presumed from the FloatTensor/LongTensor split — TODO confirm).
        cat_target = np.concatenate(data_gt, axis=0)
        # NOTE(review): torch.autograd.Variable is a deprecated no-op wrapper on
        # modern PyTorch; kept as-is for compatibility with the file's torch version.
        x_gt = torch.autograd.Variable(
            torch.FloatTensor(cat_target[:, 1]).cuda())
        y_gt = torch.autograd.Variable(
            torch.FloatTensor(cat_target[:, 2]).cuda())
        pen_down_gt = torch.autograd.Variable(
            torch.LongTensor(cat_target[:, 0]).cuda())

        output = network.forward_unlooped(data, cuda=True)
        (pen_down_prob, o_pi, o_mu1, o_mu2,
         o_sigma1, o_sigma2, o_corr) = network.go.get_mixture_coef(output)
        loss_distr, pen_loss = network.go.loss_distr(
            pen_down_prob, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
            x_gt, y_gt, pen_down_gt)
        total_loss = loss_distr + pen_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Summarization
        utils.write_summaries(['data/LossDistr', 'data/Lossloss'],
                              [loss_distr, pen_loss], [0, 0], writer,
                              jter + jter_count)
        if jter % gv.update_step == 0:
            print('cur_iter', jter, "cur_loss",
                  loss_distr.cpu().data.numpy(), 'batch_size', len(data))
            # Visualize only the first sequence of the batch: slice every
            # mixture component to its length before sampling.
            first_seq_len = data[0].shape[0]
            output = output[:first_seq_len]
            predicted_action = network.go.sample_action(
                pen_down_prob[:first_seq_len], o_pi[:first_seq_len],
                o_mu1[:first_seq_len], o_mu2[:first_seq_len],
                o_sigma1[:first_seq_len], o_sigma2[:first_seq_len],
                o_corr[:first_seq_len])
            loss_l2, pen_acc = network.go.val_loss(
                predicted_action[:first_seq_len], x_gt[:first_seq_len],
                y_gt[:first_seq_len], pen_down_gt[:first_seq_len])
            pred_image = utils.plot_stroke_numpy(
                predicted_action.cpu().numpy())
            gt_image = utils.plot_stroke_numpy(data_gt[0])
            name_list = ['data/L2Dist', 'data/PenAcc',
                         'train/PredictedSeq', 'train/GTSeq']
            value_list = [loss_l2, pen_acc, pred_image, gt_image]
            utils.write_summaries(name_list, value_list, [0, 0, 1, 1],
                                  writer, jter + jter_count)
    # NOTE(review): this adds the LAST batch index (not the batch count), so
    # the next epoch's first summary step repeats this epoch's last one, and it
    # raises NameError if the loader yields no batches. Preserved as-is since
    # callers may depend on the existing step numbering — confirm and consider
    # `jter_count + jter + 1`.
    jter_count = jter_count + jter
    return jter_count
def val(val_data_loader, network, writer, jter_count):
    """Run one validation pass over `val_data_loader` and return the mean L2 loss.

    For every batch: builds ground-truth tensors, runs the network, computes the
    mixture/pen losses, samples a predicted sequence, accumulates its L2 loss,
    and logs scalar summaries. Every `gv.update_step` batches it also logs
    predicted/ground-truth stroke images and prints progress.

    Args:
        val_data_loader: loader exposing ``get_data_single()`` yielding
            ``(data, data_gt)`` batches (project type).
        network: model exposing ``forward_unlooped`` and a ``.go`` head with
            ``get_mixture_coef`` / ``loss_distr`` / ``sample_action`` / ``val_loss``.
        writer: summary writer forwarded to ``utils.write_summaries``.
        jter_count: global step offset used for summary step numbering.

    Returns:
        Tuple ``(mean_l2_loss, last_batch_index)``.
    """
    data_fetcher = val_data_loader.get_data_single()
    loss_list = []
    for jter, (data, data_gt) in enumerate(data_fetcher):
        # Targets: column 0 = pen-down flag, columns 1/2 = x/y coordinates
        # (presumed from the FloatTensor/LongTensor split — TODO confirm).
        cat_target = np.concatenate(data_gt, axis=0)
        x_gt = torch.autograd.Variable(
            torch.FloatTensor(cat_target[:, 1]).cuda())
        y_gt = torch.autograd.Variable(
            torch.FloatTensor(cat_target[:, 2]).cuda())
        pen_down_gt = torch.autograd.Variable(
            torch.LongTensor(cat_target[:, 0]).cuda())

        output = network.forward_unlooped(data, cuda=True)
        (pen_down_prob, o_pi, o_mu1, o_mu2,
         o_sigma1, o_sigma2, o_corr) = network.go.get_mixture_coef(output)
        loss_distr, pen_loss = network.go.loss_distr(
            pen_down_prob, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
            x_gt, y_gt, pen_down_gt)
        predicted_action = network.go.sample_action(
            pen_down_prob, o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr)
        loss_l2, pen_acc = network.go.val_loss(
            predicted_action, x_gt, y_gt, pen_down_gt)
        loss_list.append(loss_l2.cpu().data.numpy())

        # Summarization
        name_list = ['data/LossDistr', 'data/Lossloss',
                     'data/L2Dist', 'data/PenAcc']
        value_list = [loss_distr, pen_loss, loss_l2, pen_acc]
        utils.write_summaries(name_list, value_list, [0] * 4, writer,
                              jter + jter_count)
        # Merged two back-to-back identical `jter % gv.update_step == 0`
        # guards from the original; image logging still precedes the print.
        if jter % gv.update_step == 0:
            pred_image = utils.plot_stroke_numpy(
                predicted_action.cpu().numpy())
            gt_image = utils.plot_stroke_numpy(data_gt[0])
            utils.write_summaries(['val/PredictedSeq', 'val/GTSeq'],
                                  [pred_image, gt_image], [1, 1], writer,
                                  jter + jter_count)
            print('cur_iter', jter, "cur_loss", loss_l2.cpu().data.numpy(),
                  'loss_pen', pen_acc.cpu().data.numpy())

    loss = sum(loss_list) / float(jter + 1)
    print('==========TOTAL VAL LOSS', loss, " ====================")
    # NOTE(review): the original computed `jter_count = jter + jter_count`
    # but then returned `jter`, leaving the updated count unused (dead store
    # removed). It likely intended to return the updated count like train();
    # the existing return value is preserved so callers keep working — confirm.
    return loss, jter