plot_kldiv_loss = [None for x in range(total_iter)] dyn_plot = DynamicPlot(title='Training loss over epochs (GRASS)', xdata=plot_x, ydata={ 'Total_loss': plot_total_loss, 'Reconstruction_loss': plot_recon_loss, 'KL_divergence_loss': plot_kldiv_loss }) iter_id = 0 max_loss = 0 for epoch in range(config.epochs): print(header) for batch_idx, batch in enumerate(train_iter): # Initialize torchfold for *encoding* enc_fold = FoldExt(cuda=config.cuda) enc_fold_nodes = [] # list of fold nodes for encoding # Collect computation nodes recursively from encoding process for example in batch: enc_fold_nodes.append( grassmodel.encode_structure_fold(enc_fold, example)) # Apply the computations on the encoder model enc_fold_nodes = enc_fold.apply(encoder, [enc_fold_nodes]) # Split into a list of fold nodes per example enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0) # Initialize torchfold for *decoding* dec_fold = FoldExt(cuda=config.cuda) # Collect computation nodes recursively from decoding process dec_fold_nodes = [] kld_fold_nodes = [] for example, fnode in zip(batch, enc_fold_nodes):
total_iter = config.epochs * len(train_iter) step = 0 for epoch in range(config.epochs): scheduler.step() print(header) for batch_idx, batch in enumerate(train_iter): # compute points feature input_data = batch[1].cuda() jitter_input = torch.randn(input_data.size()).cuda() jitter_input = torch.clamp(0.01 * jitter_input, min=-0.05, max=0.05) jitter_input += input_data points_feature = net.pointnet(jitter_input) # Split into a list of fold nodes per example enc_points_feature = torch.split(points_feature, 1, 0) # Initialize torchfold for *decoding* dec_fold = FoldExt(cuda=config.cuda) # Collect computation nodes recursively from decoding process dec_fold_nodes_label = [] dec_fold_nodes_box = [] dec_fold_nodes_acc = [] for example, points_f in zip(batch[0], enc_points_feature): labelloss, boxloss, acc = partnet_model.decode_structure_fold( dec_fold, example, points_f) dec_fold_nodes_label.append(labelloss) dec_fold_nodes_box.append(boxloss) dec_fold_nodes_acc.append(acc) # Apply the computations on the decoder model dec_loss = dec_fold.apply( net, [dec_fold_nodes_label, dec_fold_nodes_box, dec_fold_nodes_acc]) num_nodes = torch.cat([x.n_nodes for x in batch[0]], 0)
fd_log.flush() header = ' Time Epoch Iteration Progress(%) LR LabelLoss TotalLoss' log_template = ' '.join( '{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.9f},{:>10.4f},{:>10.4f}' .split(',')) print(header) total_iter = config.epochs * len(train_iter) learning_rate = config.lr count = 0 for epoch in range(config.epochs): encoder_decoder_opt = torch.optim.Adam(leaf_classification.parameters(), lr=learning_rate) for batch_idx, batch in enumerate(train_iter): enc_fold = FoldExt(cuda=config.cuda) enc_dec_fold_nodes = [] for example in batch: enc_dec_fold_nodes.append( leafclassificationmodel.leaf_classification_fold( enc_fold, example)) total_loss = enc_fold.apply(leaf_classification, [enc_dec_fold_nodes]) label_loss = total_loss[0].sum() / len(batch) encoder_decoder_opt.zero_grad() label_loss.backward() encoder_decoder_opt.step() # Report statistics if batch_idx % config.show_log_every == 0:
'params': encoder_decoder.adj_decoder.parameters() }, { 'params': encoder_decoder.sym_decoder.parameters() }, { 'params': encoder_decoder.sample_decoder.parameters() }, { 'params': encoder_decoder.box_encoder.encoder.parameters(), 'lr': learning_rate * 0.01 }, { 'params': encoder_decoder.node_classifier.parameters(), 'lr': learning_rate * 0.01 }], lr=learning_rate) for batch_idx, batch in enumerate(train_iter): enc_fold = FoldExt(cuda=config.cuda) enc_dec_fold_nodes = [] enc_dec_recon_fold_nodes = [] enc_dec_label_fold_nodes = [] for example in batch: enc_dec_fold_nodes.append( grassmodel.encode_decode_structure_fold(enc_fold, example)) enc_dec_recon_fold_nodes.append( grassmodel.encode_decode_recon_structure_fold( enc_fold, example)) enc_dec_label_fold_nodes.append( grassmodel.encode_decode_label_structure_fold( enc_fold, example)) total_loss = enc_fold.apply(encoder_decoder, [ enc_dec_fold_nodes, enc_dec_recon_fold_nodes,