Example #1
0
    # Train for one epoch: iterate minibatches, accumulate cross-entropy loss,
    # and take one optimizer step per batch.
    # NOTE(review): fragment of a larger loop -- `epoch`, `train_loss`, `start`,
    # `optimizer`, `unet`, `data`, `use_cuda`, and `exp_name` come from the
    # enclosing (not shown) scope.
    for i, batch in enumerate(data.train_data_loader):
        optimizer.zero_grad()
        if use_cuda:
            # Only the feature tensor batch['x'][1] and the labels are moved to
            # GPU; batch['x'][0] stays on CPU -- presumably it holds coordinates
            # (sparseconvnet-style input pair) -- TODO confirm.
            batch['x'][1] = batch['x'][1].cuda()
            batch['y'] = batch['y'].cuda()
        predictions = unet(batch['x'])
        loss = torch.nn.functional.cross_entropy(predictions, batch['y'])
        train_loss += loss.item()  # .item() yields a Python float; no graph kept alive
        loss.backward()
        optimizer.step()
    # NOTE(review): `i` leaks out of the loop above; if the loader is empty this
    # raises NameError.  train_loss / (i + 1) is the mean per-batch loss.
    # The scn counters are module-level globals accumulated during forward passes.
    print(epoch, 'Train loss', train_loss / (i + 1), 'MegaMulAdd=',
          scn.forward_pass_multiplyAdd_count / len(data.train) / 1e6,
          'MegaHidden', scn.forward_pass_hidden_states / len(data.train) / 1e6,
          'time=',
          time.time() - start, 's')
    # Persist model weights for this epoch via the scn helper.
    scn.checkpoint_save(unet, exp_name, 'unet', epoch, use_cuda)

    # Periodic validation pass.  NOTE(review): `epoch % 100 == 1` fires at
    # epochs 1, 101, 201, ... -- confirm the offset from a multiple of 100 is
    # intentional rather than a typo for `% 100 == 0`.
    if epoch % 100 == 1:
        with torch.no_grad():  # inference only: disable autograd bookkeeping
            unet.eval()
            # Accumulator of per-point class scores across repeated passes.
            # data.valOffsets[-1] is the total point count; 20 appears to be a
            # hard-coded number of classes -- TODO confirm.
            store = torch.zeros(data.valOffsets[-1], 20)
            # Reset scn's global FLOP / hidden-state counters so the numbers
            # reported later cover only this validation run.
            scn.forward_pass_multiplyAdd_count = 0
            scn.forward_pass_hidden_states = 0
            start = time.time()
            # Sum predictions over data.val_reps repetitions -- presumably each
            # rep uses a different augmentation of the same points; verify.
            for rep in range(1, 1 + data.val_reps):
                for i, batch in enumerate(data.val_data_loader):
                    if use_cuda:
                        batch['x'][1] = batch['x'][1].cuda()
                        batch['y'] = batch['y'].cuda()
                    predictions = unet(batch['x'])
                    # Scatter-add this batch's scores into the per-point store,
                    # routed by batch['point_ids'] (moved to CPU first).
                    store.index_add_(0, batch['point_ids'], predictions.cpu())
Example #2
0
        # Collect the current learning rate(s) for logging -- one entry per
        # param_group.  NOTE(review): fragment of a larger loop; `lrs`,
        # `iteration`, `num_iterations`, `loss`, `iter_start`,
        # `batch_num_points`, and `logfile` come from the enclosing (not shown)
        # scope.  If `lrs` is not re-initialised each iteration outside this
        # fragment, it grows unboundedly -- verify against the caller.
        for param_group in optimizer.param_groups:  # can there be several param_groups?
            lrs.append(param_group["lr"])
        # lrs = optimizer.param_groups[0]["lr"]
        # Per-iteration progress line, echoed to stdout and appended to the log.
        s = "Iteration: {}/{}, train loss = {:.5f}, time = {:4.1f}, lr = {}, points in batch = {}". \
            format(iteration, num_iterations, loss.item(), time.time() - iter_start, lrs, batch_num_points)
        print(s)
        logfile.write(s + '\n')
        iteration += 1
        # del batch

    # End-of-epoch summary.  NOTE(review): `i`, `lrs`, `train_loss`, and
    # `epoch_start` leak in from the (not shown) inner training loop; an empty
    # loader would make `i` undefined here.
    s = "EPOCH: {}, train loss = {:.5f}, time = {:4.1f}, lr = {}\n".format(epoch, train_loss / (i + 1),
                                                                           time.time() - epoch_start, lrs)
    print(s)
    logfile.write(s + '\n')

    # Save a checkpoint; the helper is passed save_frequency=2, so it
    # presumably decides internally whether this epoch is actually written.
    scn.checkpoint_save(unet, optimizer, exp_name, 'unet', epoch, use_cuda=use_cuda, save_frequency=2)
    # validate
    if epoch % 2 == 0:  # validate every second epoch
        with torch.no_grad():  # inference only: disable autograd bookkeeping
            unet.eval()
            # Reset scn's global FLOP / hidden-state counters so any later
            # report covers only this evaluation run.
            scn.forward_pass_multiplyAdd_count = 0
            scn.forward_pass_hidden_states = 0
            save = False
            print("\nEvaluation")
            start = time.time()
            # Repeat validation data.val_reps times -- presumably with varied
            # augmentations; TODO confirm.
            for rep in range(1, 1 + data.val_reps):
                # Fresh per-rep accumulators for predictions and ground truth.
                all_pred = np.array([], dtype=int)
                valLabels = np.array([], dtype=int)
                num_batches = len(val_data_loader)
                for i, batch in enumerate(val_data_loader):
                    print(">>>Processing batch: {}/{}".format(i + 1, num_batches))