Beispiel #1
0
def plot_joints_from_heatmaps(heatmaps, data=None, title='', fig=None, linewidth=2):
    """Plot joint positions decoded from heatmaps, optionally over an image.

    The heatmaps are converted to colorspace joint coordinates and the joint
    skeleton is drawn; when *data* is given, it is converted to a plottable
    RGB array and drawn on the same figure with *title*.

    Args:
        heatmaps: per-joint heatmaps accepted by
            conv.heatmaps_to_joints_colorspace.
        data: optional raw image array to draw under the joints.
        title: figure title used when data is given.
        fig: matplotlib figure to draw on; a new one is created when None.
        linewidth: line width of the plotted joint skeleton.

    Returns:
        The matplotlib figure that was drawn on.
    """
    if fig is None:
        fig = plt.figure()
    joints_colorspace = conv.heatmaps_to_joints_colorspace(heatmaps)
    fig = plot_joints(joints_colorspace, fig=fig, linewidth=linewidth)
    # PEP 8 idiom: "x is not None" instead of "not x is None"
    if data is not None:
        data_img_RGB = conv.numpy_to_plottable_rgb(data)
        fig = plot_img_RGB(data_img_RGB, fig=fig, title=title)
    return fig
Beispiel #2
0
def plot_halnet_joints_from_heatmaps_crop(halnet_main_out, img_numpy, filenamebase, plot=True):
    """Crop img_numpy around the joints decoded from HALNet's main heatmaps.

    Joint positions in colorspace are decoded from halnet_main_out, the
    image and labels are cropped around them, and — unless *plot* is
    False — the crop is displayed with the joints overlaid.

    Args:
        halnet_main_out: HALNet main-output heatmaps.
        img_numpy: input image array to crop.
        filenamebase: identifier used in plot titles.
        plot: when True, show the cropped image with joints.

    Returns:
        The cropped image.
    """
    joints_cs = conv.heatmaps_to_joints_colorspace(halnet_main_out)
    crop_result = converter.crop_image_get_labels(img_numpy, joints_cs, range(21))
    data_crop, crop_coords, labels_heatmaps, labels_colorspace = crop_result
    if not plot:
        return data_crop
    fig = visualize.create_fig()
    visualize.plot_image(data_crop, title=filenamebase, fig=fig)
    visualize.plot_joints_from_colorspace(labels_colorspace, title=filenamebase, fig=fig, data=data_crop)
    visualize.title('HALNet (joints from heatmaps - cropped): ' + filenamebase)
    visualize.show()
    return data_crop
    print('\tHALNet:')
    start = time.time()
    halnet_input = conv.data_to_batch(img_data)
    print_time('\t\tHALNet image convertion: ', time.time() - start)

    start = time.time()
    output_halnet = halnet(halnet_input)
    print_time('\t\tHALNet pass: '******'\t\t\tHALNet hand root:\t{}'.format(halnet_handroot))

    # get halnet joints in colorspace from heatmaps
    halnet_joints_colorspace = conv.heatmaps_to_joints_colorspace(
        halnet_main_out)

    print('\tHALNet joint pixel loss:')
    num_valid_loss = 0
    loss_halnet_joints = 0
    halnet_out_fingertips = np.zeros((5, 2))
    for i in range(NUM_JOINTS):
        idx = get_label_index(dataset_name, i)
        if np.sum(img_labels_2D[i, :]) > 0:
            curr_loss = np.linalg.norm(img_labels_2D[i, :] -
                                       halnet_joints_colorspace[idx, :])
            loss_halnet_joints += curr_loss
            num_valid_loss += 1
            losses_halnet_per_joints[example_ix, i] = curr_loss
            print('\t\t\tJoint {}: {}\t{} : {}'.format(
                i, img_labels_2D[i, :], halnet_joints_colorspace[idx, :],
def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    """Run one validation pass of the model over valid_loader.

    Each batch is scored with either a cross-entropy or a Euclidean heatmap
    loss; total and per-joint pixel losses are accumulated in valid_vars.
    Every control_vars['iter_size'] sub-mini-batches, the accumulators are
    appended to their history lists and reset, the best loss so far is
    tracked, and every control_vars['log_interval'] iterations a checkpoint
    (model + optimizer state) is saved.

    Args:
        valid_loader: loader yielding (data, target) where target unpacks
            to (target_heatmaps, target_joints, target_joints_z).
        model: network exposing joint_ixs and WEIGHT_LOSS_* constants.
        optimizer: optimizer whose state_dict is stored in checkpoints.
        valid_vars: dict of loss accumulators, history lists, best loss,
            'use_cuda'/'cross_entropy' flags and checkpoint settings.
        control_vars: dict of iteration bookkeeping (iter_size, curr_iter,
            log_interval, batch sizes, timing totals, ...). Mutated in place.
        verbose: forwarded to print_verbose for progress output.

    Returns:
        Tuple (valid_vars, control_vars) with updated accumulators/counters.
    """
    # NOTE(review): this local is never read again; the epoch-iteration
    # counter actually used below is control_vars['curr_epoch_iter'].
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx + 1) + "/" + str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # unpack targets; wrap data and heatmaps as autograd Variables
        target_heatmaps, target_joints, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if valid_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(loss_func,
            output, target_heatmaps, model.joint_ixs, model.WEIGHT_LOSS_INTERMED1,
            model.WEIGHT_LOSS_INTERMED2, model.WEIGHT_LOSS_INTERMED3,
            model.WEIGHT_LOSS_MAIN, control_vars['iter_size'])

        if DEBUG_VISUALLY:
            # plot each example's crop with joints decoded from the main
            # heatmap output (output[3])
            for i in range(control_vars['max_mem_batch']):
                filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
                filenamebase = valid_loader.dataset.get_filenamebase(filenamebase_idx)
                fig = visualize.create_fig()
                #visualize.plot_joints_from_heatmaps(output[3][i].data.numpy(), fig=fig,
                #                                    title=filenamebase, data=data[i].data.numpy())
                #visualize.plot_image_and_heatmap(output[3][i][8].data.numpy(),
                #                                 data=data[i].data.numpy(),
                #                                 title=filenamebase)
                #visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_heatmap')

                labels_colorspace = conv.heatmaps_to_joints_colorspace(output[3][i].data.numpy())
                data_crop, crop_coords, labels_heatmaps, labels_colorspace = \
                    converter.crop_image_get_labels(data[i].data.numpy(), labels_colorspace, range(21))
                visualize.plot_image(data_crop, title=filenamebase, fig=fig)
                visualize.plot_joints_from_colorspace(labels_colorspace, title=filenamebase, fig=fig, data=data_crop)
                #visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_crop')
                visualize.show()

        #loss.backward()
        valid_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if valid_vars['cross_entropy']:
            valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            # sampling-based pixel loss is only defined for cross-entropy runs
            valid_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].item())
            # erase total loss
            total_loss = valid_vars['total_loss'].item()
            valid_vars['total_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if valid_vars['losses'][-1] < valid_vars['best_loss']:
                valid_vars['best_loss'] = valid_vars['losses'][-1]
                #print_verbose("  This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars)
                # NOTE(review): valid_vars is saved under the 'train_vars'
                # key — presumably so the checkpoint format matches the one
                # written during training; confirm against the loader.
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(model_dict,
                                        filename=valid_vars['checkpoint_filenamebase'] +
                                                 str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)

            # advance iteration counters for the completed mini-batch
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1


    return valid_vars, control_vars
Beispiel #5
0
    start = time.time()
    output_halnet = halnet(conv.data_to_batch(data))
    print_time('HALNet pass: '******'Handroot (colorspace):\t{}'.format(handroot_colorspace))
    print('Handroot (colorspace), z:\t{}'.format(
        img_numpy[3, handroot_colorspace[0], handroot_colorspace[1]]))
    print('Handroot (depthspace):\t{}'.format(handroot))
    labels_colorspace = conv.heatmaps_to_joints_colorspace(halnet_main_out)

    data_crop, _, _, _ = io_image.crop_image_get_labels(
        img_numpy, labels_colorspace, range(21))
    batch_jornet = conv.data_to_batch(data_crop)
    print_time('JORNet image conversion: ', time.time() - start)

    start = time.time()
    output_jornet = jornet(batch_jornet)
    print_time('JORNet pass: ', time.time() - start)

    start = time.time()
    jornet_joints_mainout = output_jornet[7][0].data.cpu().numpy()

    jornet_joints_global = conv.jornet_local_to_global_joints(
        jornet_joints_mainout, handroot)