Example #1
    plot_kldiv_loss = [None for x in range(total_iter)]
    dyn_plot = DynamicPlot(title='Training loss over epochs (GRASS)',
                           xdata=plot_x,
                           ydata={
                               'Total_loss': plot_total_loss,
                               'Reconstruction_loss': plot_recon_loss,
                               'KL_divergence_loss': plot_kldiv_loss
                           })
    iter_id = 0
    max_loss = 0

for epoch in range(config.epochs):
    print(header)
    for batch_idx, batch in enumerate(train_iter):
        # Initialize torchfold for *encoding*
        enc_fold = FoldExt(cuda=config.cuda)
        enc_fold_nodes = []  # list of fold nodes for encoding
        # Collect computation nodes recursively from encoding process
        for example in batch:
            enc_fold_nodes.append(
                grassmodel.encode_structure_fold(enc_fold, example))
        # Apply the computations on the encoder model
        enc_fold_nodes = enc_fold.apply(encoder, [enc_fold_nodes])
        # Split into a list of fold nodes per example
        enc_fold_nodes = torch.split(enc_fold_nodes[0], 1, 0)
        # Initialize torchfold for *decoding*
        dec_fold = FoldExt(cuda=config.cuda)
        # Collect computation nodes recursively from decoding process
        dec_fold_nodes = []
        kld_fold_nodes = []
        for example, fnode in zip(batch, enc_fold_nodes):
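Note: FoldExt here implements torchfold-style dynamic batching. encode_structure_fold only records which ops to run on which inputs; enc_fold.apply then executes each distinct op once over the concatenated inputs, and torch.split recovers one result row per example. A minimal pure-PyTorch sketch of that effect (the leaf_encoder stand-in is hypothetical):

import torch
import torch.nn as nn

leaf_encoder = nn.Linear(4, 8)                 # stand-in for one encoder op
boxes = [torch.randn(1, 4) for _ in range(5)]  # one (1, 4) input per example

# Unbatched: one encoder call per example (what the fold avoids).
codes_slow = [leaf_encoder(b) for b in boxes]

# Dynamic batching, as the fold does internally: run the op once on the
# concatenated inputs, then split back into per-example rows (cf. the
# torch.split(enc_fold_nodes[0], 1, 0) call above).
codes_fast = torch.split(leaf_encoder(torch.cat(boxes, 0)), 1, 0)

for slow, fast in zip(codes_slow, codes_fast):
    assert torch.allclose(slow, fast, atol=1e-6)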
Example #2
File: train.py  Project: yuangan/PartNet
total_iter = config.epochs * len(train_iter)
step = 0
for epoch in range(config.epochs):
    scheduler.step()  # advance the per-epoch learning-rate schedule
    print(header)
    for batch_idx, batch in enumerate(train_iter):
        # Jitter the input points: Gaussian noise (sigma 0.01) clamped to +/-0.05
        input_data = batch[1].cuda()
        jitter_input = torch.randn(input_data.size()).cuda()
        jitter_input = torch.clamp(0.01 * jitter_input, min=-0.05, max=0.05)
        jitter_input += input_data
        # Compute per-point features on the jittered input
        points_feature = net.pointnet(jitter_input)
        # Split into a list of fold nodes per example
        enc_points_feature = torch.split(points_feature, 1, 0)
        # Initialize torchfold for *decoding*
        dec_fold = FoldExt(cuda=config.cuda)
        # Collect computation nodes recursively from decoding process
        dec_fold_nodes_label = []
        dec_fold_nodes_box = []
        dec_fold_nodes_acc = []
        for example, points_f in zip(batch[0], enc_points_feature):
            labelloss, boxloss, acc = partnet_model.decode_structure_fold(
                dec_fold, example, points_f)
            dec_fold_nodes_label.append(labelloss)
            dec_fold_nodes_box.append(boxloss)
            dec_fold_nodes_acc.append(acc)
        # Apply the computations on the decoder model
        dec_loss = dec_fold.apply(
            net,
            [dec_fold_nodes_label, dec_fold_nodes_box, dec_fold_nodes_acc])
        num_nodes = torch.cat([x.n_nodes for x in batch[0]], 0)
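The jitter block above is a standard point-cloud augmentation: draw standard-normal noise, scale it by 0.01, clamp each offset to +/-0.05, and add it to the input. A minimal sketch factored into a reusable function (the name jitter_points is hypothetical):

import torch

def jitter_points(points, sigma=0.01, clip=0.05):
    # Add clamped Gaussian jitter to a batch of point clouds (B, N, 3).
    noise = torch.clamp(sigma * torch.randn_like(points), min=-clip, max=clip)
    return points + noise

clouds = torch.rand(2, 1024, 3)      # two clouds of 1024 points each
augmented = jitter_points(clouds)    # same shape, each coordinate moved by <= 0.05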
Example #3
    fd_log.flush()

header = '     Time    Epoch     Iteration    Progress(%)    LR       LabelLoss  TotalLoss'
log_template = ' '.join(
    '{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.9f},{:>10.4f},{:>10.4f}'
    .split(','))
print(header)
total_iter = config.epochs * len(train_iter)

learning_rate = config.lr
count = 0
for epoch in range(config.epochs):
    # Rebuild the optimizer each epoch so changes to learning_rate take effect
    # (note that this also resets Adam's moment estimates every epoch)
    encoder_decoder_opt = torch.optim.Adam(leaf_classification.parameters(),
                                           lr=learning_rate)
    for batch_idx, batch in enumerate(train_iter):
        enc_fold = FoldExt(cuda=config.cuda)
        enc_dec_fold_nodes = []
        for example in batch:
            enc_dec_fold_nodes.append(
                leafclassificationmodel.leaf_classification_fold(
                    enc_fold, example))
        total_loss = enc_fold.apply(leaf_classification, [enc_dec_fold_nodes])

        label_loss = total_loss[0].sum() / len(batch)

        encoder_decoder_opt.zero_grad()
        label_loss.backward()
        encoder_decoder_opt.step()

        # Report statistics
        if batch_idx % config.show_log_every == 0:
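The log_template construction above is a readability trick: the format fields are written comma-separated, then split(',') and re-joined with single spaces so the columns line up under the header. A quick demonstration with made-up values (epoch 3/10, iteration 42/500):

header = '     Time    Epoch     Iteration    Progress(%)    LR       LabelLoss  TotalLoss'
log_template = ' '.join(
    '{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.9f},{:>10.4f},{:>10.4f}'
    .split(','))

print(header)
print(log_template.format('00:01:23', 3, 10, 42, 500, 8.4, 0.001, 0.5123, 0.7345))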
Example #4
            'params': encoder_decoder.adj_decoder.parameters()
        }, {
            'params': encoder_decoder.sym_decoder.parameters()
        }, {
            'params': encoder_decoder.sample_decoder.parameters()
        }, {
            'params': encoder_decoder.box_encoder.encoder.parameters(),
            'lr': learning_rate * 0.01
        }, {
            'params': encoder_decoder.node_classifier.parameters(),
            'lr': learning_rate * 0.01
        }],
        lr=learning_rate)

    for batch_idx, batch in enumerate(train_iter):
        enc_fold = FoldExt(cuda=config.cuda)
        enc_dec_fold_nodes = []
        enc_dec_recon_fold_nodes = []
        enc_dec_label_fold_nodes = []
        for example in batch:
            enc_dec_fold_nodes.append(
                grassmodel.encode_decode_structure_fold(enc_fold, example))
            enc_dec_recon_fold_nodes.append(
                grassmodel.encode_decode_recon_structure_fold(
                    enc_fold, example))
            enc_dec_label_fold_nodes.append(
                grassmodel.encode_decode_label_structure_fold(
                    enc_fold, example))

        total_loss = enc_fold.apply(encoder_decoder, [
            enc_dec_fold_nodes, enc_dec_recon_fold_nodes,
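Example #4 builds its Adam optimizer from parameter groups so that the box encoder and node classifier train at 1% of the base rate while the decoders use the full rate. A minimal self-contained sketch of the same per-group pattern (the two-layer model is hypothetical):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
base_lr = 1e-3

opt = torch.optim.Adam(
    [
        {'params': model[0].parameters()},                        # uses base_lr
        {'params': model[1].parameters(), 'lr': base_lr * 0.01},  # 100x smaller
    ],
    lr=base_lr)

assert opt.param_groups[0]['lr'] == base_lr
assert opt.param_groups[1]['lr'] == base_lr * 0.01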