Example #1
def get_dmap_rewards(control_points,
                     num_cp,
                     num_beziers,
                     im,
                     grid,
                     distance='l2'):
    """
    control_points.shape = (num_cp*max_beziers, batch_size, 2)
    num_cp = scalar
    num_beziers.shape = (batch_size)
    im.shape = (batch_size, 1, 64, 64)
    grid.shape = (1, 1, 64, 64, 2)
    """
    batch_size = im.shape[0]
    pred_seq = torch.empty((batch_size, 0, 1, 1, 2), device=im.device)
    dmap_rewards = torch.empty((batch_size, 0), device=im.device)

    i = 0
    not_finished = num_beziers > i
    to_end = torch.sum(not_finished)
    while to_end:

        num_cps = torch.zeros_like(num_beziers)
        num_cps[not_finished] = num_cp
        partial_pred_seq = bezier(
            control_points[num_cp * i:num_cp * i + num_cp],
            num_cps,
            torch.linspace(0, 1, 150, device=num_cps.device).unsqueeze(0),
            device=num_cps.device)
        pred_seq = torch.cat(
            (pred_seq, partial_pred_seq.unsqueeze(-2).unsqueeze(-2)), dim=1)

        # Compute the reward obtained after drawing this curve
        new_dmap = torch.sqrt(torch.sum((grid - pred_seq)**2, dim=-1))
        new_dmap, _ = torch.min(new_dmap, dim=1)
        if distance == 'quadratic':
            new_dmap = new_dmap * new_dmap
        elif distance == 'exp':
            new_dmap = torch.exp(new_dmap)
        new_rewards = -torch.sum(
            new_dmap * im[:, 0] / torch.sum(im[:, 0], dim=(1, 2)).view(-1, 1, 1),
            dim=(1, 2))
        dmap_rewards = torch.cat((dmap_rewards, new_rewards.unsqueeze(1)), dim=1)

        i += 1
        not_finished = num_beziers > i
        to_end = torch.sum(not_finished)

    return dmap_rewards
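
The reward above is driven by a per-pixel distance map: for every canvas pixel, the distance to the nearest point sampled from the predicted curves. Below is a minimal, self-contained sketch of that distance-map computation, assuming a 64x64 canvas; a hand-made diagonal point sequence stands in for the output of the external bezier() sampler, and the grid tensor is built the same way as in the training functions further down.

import torch

# Minimal sketch (not part of the original code): distance-map computation on a 64x64 canvas.
H = W = 64
batch_size = 2

# Coordinate grid of shape (1, 1, H, W, 2), built as in the training functions below.
grid = torch.empty((1, 1, H, W, 2), dtype=torch.float32)
for i in range(H):
    grid[0, 0, i, :, 0] = i
    grid[0, 0, :, i, 1] = i

# Stand-in for bezier(): 150 points per sample along the canvas diagonal.
t = torch.linspace(0, 1, 150)
points = torch.stack((63 * t, 63 * t), dim=-1).expand(batch_size, -1, -1)
pred_seq = points.unsqueeze(-2).unsqueeze(-2)  # (batch_size, 150, 1, 1, 2), as in the code above

# Distance from every pixel to every sampled point, then min over points = distance map.
dmap = torch.sqrt(torch.sum((grid - pred_seq) ** 2, dim=-1))  # (batch_size, 150, H, W)
dmap, _ = torch.min(dmap, dim=1)                              # (batch_size, H, W)
print(dmap.shape, dmap.min().item())  # pixels on the diagonal are at distance ~0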
Example #2
def loss_function(control_points, im, distance_im, covariance, probabilistic_map_generator, grid): #, repulsion_coef=0.1, dist_thresh=4.5, second_term=True):
    batch_size = control_points.shape[2]
    num_cp = control_points.shape[1]

    probability_map = torch.empty((0, batch_size, 64, 64), dtype=torch.float32, device=control_points.device)
    pred_seq = torch.empty((batch_size, 0, 1, 1, 2), dtype=torch.float32, device=control_points.device)

    #curvature_penalizations = torch.empty((batch_size, 0), dtype=torch.float32, device=control_points.device)

    num_cps = num_cp*torch.ones(batch_size, dtype=torch.long, device=control_points.device)

    for bezier_cp in control_points:
        # Compute this curve's probability map and concatenate it with those of the other curves
        partial_probability_map = probabilistic_map_generator(bezier_cp, num_cps, covariance)
        probability_map = torch.cat((probability_map, partial_probability_map.unsqueeze(0)), dim=0)

        # Compute this curve's point sequence and concatenate it with those of the other curves
        partial_pred_seq = bezier(bezier_cp, num_cps,
                                  torch.linspace(0, 1, 150, device=num_cps.device).unsqueeze(0), device=num_cps.device)
        pred_seq = torch.cat((pred_seq, partial_pred_seq.unsqueeze(-2).unsqueeze(-2)), dim=1)

        # Compute the mean or maximum curvature of the predicted curves and store it
        #new_curvatures = curvature(partial_pred_seq, mode='max')
        #curvature_penalizations = torch.cat((curvature_penalizations, new_curvatures.unsqueeze(1)), dim=1)

    # Compute the probability and distance maps for the full set of curves
    pmap, _ = torch.max(probability_map, dim=0)
    dmap = torch.sqrt(torch.sum((grid - pred_seq) ** 2, dim=-1))
    dmap, _ = torch.min(dmap, dim=1)

    # Compute the curvature penalty
    #curvature_penalizations = torch.mean(curvature_penalizations)

    # Compute the fake chamfer_distance
    fake_chamfer = torch.sum(
        im[:, 0] * dmap / torch.sum(im[:, 0], dim=(1, 2)).view(-1, 1, 1)
        + pmap * distance_im[:, 0] / torch.sum(pmap, dim=(1, 2)).view(-1, 1, 1)
    ) / batch_size  # + curv_pen_coef*curvature_penalizations

    """repulsion_penalty = 0
    if repulsion_coef > 0:
        repulsion_penalty = repulsion(control_points, dist_thresh=dist_thresh, second_term=second_term)"""

    return fake_chamfer# + repulsion_coef*repulsion_penalty
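
Spelled out, the fake_chamfer line above is a differentiable, two-sided approximation of the chamfer distance between the predicted curves and the target stroke. With $I_b$ the target image im[:, 0], $D_b$ the distance map of the predicted curves, $P_b$ their probability map and $D_b^{\mathrm{gt}}$ the target's distance map (distance_im[:, 0]), it reads roughly as

\mathcal{L}_{\text{fake\_chamfer}} \approx \frac{1}{B}\sum_{b=1}^{B}\left[\frac{\sum_{x,y} I_b(x,y)\, D_b(x,y)}{\sum_{x,y} I_b(x,y)} \;+\; \frac{\sum_{x,y} P_b(x,y)\, D_b^{\mathrm{gt}}(x,y)}{\sum_{x,y} P_b(x,y)}\right]

The first term pulls the predicted curves toward the inked target pixels; the second penalizes probability mass that the prediction places far from the target stroke.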
Example #3
def new_loss(pred_cp, groundtruth_im, groundtruth_seq, grid):
    """
    pred_cp.shape = (num_beziers, num_cp, batch_size, 2)
    groundtruth_seq.shape = (bs, max_N, 2)
    groundtruth_im.shape = (bs, 1, 64, 64)
    grid.shape = (1, 1, 64, 64, 2)

    Since the training images have different numbers of points on their curves, max_N is the highest number of points
    among all the dataset images. That means that for almost every image we need to pad the tensor groundtruth_seq
    (to fill the max_N coordinates); the padding uses a constant whose distance to any point of the canvas is large
    enough for it never to be considered in the computation of the chamfer distance.
    """
    batch_size = pred_cp.shape[2]
    num_cp = pred_cp.shape[1]
    groundtruth_im = groundtruth_im[:, 0]

    # Computation of the sequence of coordinates of the predicted bézier curves
    num_cps = num_cp * torch.ones(batch_size, dtype=torch.long, device=pred_cp.device)
    pred_seq = torch.empty((batch_size, 0, 1, 2), dtype=torch.float32, device=pred_cp.device)
    for bezier_cp in pred_cp:
        partial_pred_seq = bezier(bezier_cp, num_cps,
                                  torch.linspace(0, 1, 150, device=num_cps.device).unsqueeze(0), device=num_cps.device)
        pred_seq = torch.cat((pred_seq, partial_pred_seq.unsqueeze(-2)), dim=1)


    """Computation of the first side of the chamfer distance sum(pred_im*dmap(original_im))/sum(pred_im)"""
    # pred_seq.shape = (bs, N, 1, 2)
    groundtruth_seq = groundtruth_seq.unsqueeze(1) #groundtruth_seq.shape = (bs, 1, max_N, 2)
    # Computation of the fake chamfer distance
    temp = torch.sqrt(torch.sum((pred_seq - groundtruth_seq) ** 2, dim=-1))  # temp.shape = (bs, N, max_N)
    fact1 = torch.mean(torch.min(temp, dim=-1).values, dim=-1)

    """Computation of the second side of the chamfer distance sum(dmap(pred_im)*original_im)/sum(original_im)"""
    # Computation of the distance map
    dmap = torch.sqrt(torch.sum((grid - pred_seq.unsqueeze(-2)) ** 2, dim=-1))
    dmap, _ = torch.min(dmap, dim=1)
    fact2 = torch.sum(groundtruth_im * dmap, dim=(1, 2)) / torch.sum(groundtruth_im, dim=(1, 2))

    return torch.mean(fact1 + fact2)
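
Below is a minimal sketch of the padding described in the docstring above. The pad value of 10 * im_size is an arbitrary choice made here for illustration; any constant far enough outside the canvas works, since it can then never win the min inside the chamfer computation.

import torch

def pad_groundtruth_sequences(images, pad_multiplier=10):
    """images: (N, 1, H, W) binary masks -> (N, max_N, 2) padded coordinate sequences."""
    im_size = images.shape[-1]
    pad_value = pad_multiplier * im_size  # far from every canvas point
    coords = [torch.nonzero(im[0]) for im in images]   # per-image (n_i, 2) coordinate lists
    max_coords = max(c.shape[0] for c in coords)
    groundtruth_seq = torch.full((images.shape[0], max_coords, 2), pad_value, dtype=torch.long)
    for i, c in enumerate(coords):
        groundtruth_seq[i, :c.shape[0]] = c
    return groundtruth_seq

# Toy usage: two 64x64 images with different numbers of inked pixels.
imgs = torch.zeros((2, 1, 64, 64))
imgs[0, 0, 10, 10:20] = 1
imgs[1, 0, 5, 5] = 1
print(pad_groundtruth_sequences(imgs).shape)  # torch.Size([2, 10, 2])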
Example #4
def get_chamfer_rewards(control_points, num_cp, num_beziers, im, distance_im,
                        covariance, probabilistic_map_generator, grid):
    """
    control_points.shape = (num_cp*max_beziers, batch_size, 2)
    num_cp = scalar
    num_beziers.shape = (batch_size)
    covariance.shape = (num_cp, batch_size, 2, 2)
    im.shape = (batch_size, 1, 64, 64)
    distance_im.shape = (batch_size, 1, 64, 64)
    """
    batch_size = im.shape[0]

    probability_map = torch.empty((0, batch_size, 64, 64),
                                  dtype=torch.float32,
                                  device=im.device)
    pred_seq = torch.empty((batch_size, 0, 1, 1, 2),
                           dtype=torch.float32,
                           device=im.device)
    chamfer_rewards = torch.empty((0, batch_size),
                                  dtype=torch.float32,
                                  device=im.device)

    i = 0
    not_finished = num_beziers > i
    to_end = torch.sum(not_finished)
    while to_end:
        num_cps = torch.zeros_like(num_beziers)
        num_cps[not_finished] = num_cp

        partial_probability_map = probabilistic_map_generator(
            control_points[num_cp * i:num_cp * i + num_cp], num_cps,
            covariance)
        probability_map = torch.cat(
            (probability_map, partial_probability_map.unsqueeze(0)), dim=0)

        partial_pred_seq = bezier(
            control_points[num_cp * i:num_cp * i + num_cp],
            num_cps,
            torch.linspace(0, 1, 150, device=num_cps.device).unsqueeze(0),
            device=num_cps.device)
        pred_seq = torch.cat(
            (pred_seq, partial_pred_seq.unsqueeze(-2).unsqueeze(-2)), dim=1)

        # Compute the reward obtained after drawing this curve
        pmap, _ = torch.max(probability_map, dim=0)
        dmap = torch.sqrt(torch.sum((grid - pred_seq)**2, dim=-1))
        dmap, _ = torch.min(dmap, dim=1)

        new_rewards = -torch.sum(
            im[:, 0] * dmap / torch.sum(im[:, 0], dim=(1, 2)).view(-1, 1, 1) +
            pmap * distance_im[:, 0] /
            torch.sum(pmap, dim=(1, 2)).view(-1, 1, 1),
            dim=(1, 2))
        chamfer_rewards = torch.cat(
            (chamfer_rewards, new_rewards.unsqueeze(0)), dim=0)

        i += 1
        not_finished = num_beziers > i
        to_end = torch.sum(not_finished)

    return chamfer_rewards
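
Both reward functions iterate curve by curve until every sample in the batch has drawn its own number of Bézier curves. The masking and termination pattern they share, stripped of the curve evaluation itself, is sketched below with a toy num_beziers tensor (illustration only; the real code feeds num_cps to the external bezier() call):

import torch

num_cp = 3
num_beziers = torch.tensor([1, 3, 2])  # per-sample number of curves (toy values)

i = 0
not_finished = num_beziers > i
to_end = torch.sum(not_finished)
while to_end:
    # Samples that still have curves left get num_cp control points; finished ones get 0,
    # which lets the downstream bezier() call skip them.
    num_cps = torch.zeros_like(num_beziers)
    num_cps[not_finished] = num_cp
    print(f"step {i}: num_cps = {num_cps.tolist()}")

    i += 1
    not_finished = num_beziers > i
    to_end = torch.sum(not_finished)
# Prints three steps; after step 2 every sample is finished and the loop stops.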
Example #5
def train_one_bezier_transformer(model, dataset, batch_size, num_epochs, optimizer, loss_mode,
                                 num_experiment, cp_variance, var_drop, epochs_drop, min_variance, penalization_coef,
                                 lr=1e-4, cuda=True, debug=True):
    # torch.autograd.set_detect_anomaly(True)
    print("\n\nTHE TRAINING BEGINS")
    print("MultiBezier Experiment #{} ---> loss={} distance_loss={} num_cp={} max_beziers={} batch_size={} num_epochs={} learning_rate={} pen_coef={}".format(
        num_experiment, loss_mode[0], loss_mode[1], model.num_cp, model.max_beziers, batch_size, num_epochs, lr, penalization_coef))

    # basedir = "/data1slow/users/asuso/trans_bezier"
    basedir = "/home/asuso/PycharmProjects/trans_bezier"

    # Initialize a variable that will store the best loss obtained on validation
    best_loss = float('inf')

    # Initialize the tensorboard writer and the variables we will use for
    # data collection.
    cummulative_loss = 0
    if debug:
        # Tensorboard writer
        writer = SummaryWriter(basedir + "/graphics/ProbabilisticBezierEncoder/MultiBezierModels/FixedCP/"+str(model.num_cp)+"CP_maxBeziers"+str(model.max_beziers)+"loss"+str(loss_mode[0]))
        counter = 0

    # Get the dataset images
    images = dataset

    assert loss_mode[0] in ['pmap', 'dmap', 'chamfer']
    assert loss_mode[1] in ['l2', 'quadratic', 'exp']
    # If the selected loss is the probabilistic-map loss, do the necessary setup
    if loss_mode[0] == 'pmap':
        # Initialize the probabilistic map generator and the covariance matrices
        probabilistic_map_generator = ProbabilisticMap((model.image_size, model.image_size, 50))
        cp_covariance = torch.tensor([ [[1, 0], [0, 1]] for i in range(model.num_cp)], dtype=torch.float32)
        cp_covariances = torch.empty((model.num_cp, batch_size, 2, 2))
        for i in range(batch_size):
            cp_covariances[:, i, :, :] = cp_covariance

        # Get the groundtruth images with the penalty factor added
        loss_images = generate_loss_images(images, weight=penalization_coef, distance=loss_mode[1])
        grid = None
        distance_im = None
        if cuda:
            images = images.cuda()
            loss_images = loss_images.cuda()

            probabilistic_map_generator = probabilistic_map_generator.cuda()
            cp_covariances = cp_covariances.cuda()
            model = model.cuda()

    # If the selected loss is the distance-map loss, do the necessary setup
    if loss_mode[0] == 'dmap':
        grid = torch.empty((1, 1, images.shape[2], images.shape[3], 2), dtype=torch.float32)
        for i in range(images.shape[2]):
            grid[0, 0, i, :, 0] = i
            grid[0, 0, :, i, 1] = i
        loss_im = None
        distance_im = None
        actual_covariances = None
        probabilistic_map_generator = None
        if cuda:
            images = images.cuda()

            grid = grid.cuda()
            model = model.cuda()

    # If the selected loss is the chamfer loss, do the necessary setup
    if loss_mode[0] == 'chamfer':
        # Initialize the probabilistic map generator and the covariance matrices
        probabilistic_map_generator = ProbabilisticMap((model.image_size, model.image_size, 50))
        cp_covariance = torch.tensor([ [[1, 0], [0, 1]] for i in range(model.num_cp)], dtype=torch.float32)
        actual_covariances = torch.empty((model.num_cp, batch_size, 2, 2))
        for i in range(batch_size):
            actual_covariances[:, i, :, :] = cp_covariance
        # Build the grid
        grid = torch.empty((1, 1, images.shape[2], images.shape[2], 2), dtype=torch.float32)
        for i in range(images.shape[2]):
            grid[0, 0, i, :, 0] = i
            grid[0, 0, :, i, 1] = i
        # Get the distance_images
        distance_images = generate_distance_images(images)
        loss_im = None
        if cuda:
            images = images.cuda()
            distance_images = distance_images.cuda()

            probabilistic_map_generator = probabilistic_map_generator.cuda()
            actual_covariances = actual_covariances.cuda()
            grid = grid.cuda()
            model = model.cuda()

    # Split the dataset into training and validation
    # images.shape=(N, 1, 64, 64)
    im_training = images[:40000]
    im_validation = images[40000:]
    if loss_mode[0] == 'pmap':
        loss_im_training = loss_images[:40000]
        loss_im_validation = loss_images[40000:]
    if loss_mode[0] == 'chamfer':
        distance_im_training = distance_images[:40000]
        distance_im_validation = distance_images[40000:]

    # Define the optimizer
    optimizer = optimizer(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=10**(-0.5), patience=8, min_lr=1e-8)

    for epoch in range(num_epochs):
        t0 = time.time()
        print("Beginning epoch number", epoch+1)
        if loss_mode[0] == 'pmap':
            actual_covariances = cp_covariances * step_decay(cp_variance, epoch, var_drop, epochs_drop, min_variance).to(cp_covariances.device)
        for i in range(0, len(im_training)-batch_size+1, batch_size):
            # Get the batch
            im = im_training[i:i+batch_size]#.cuda()
            if loss_mode[0] == 'pmap':
                loss_im = loss_im_training[i:i+batch_size]#.cuda()
            if loss_mode[0] == 'chamfer':
                distance_im = distance_im_training[i:i+batch_size]#.cuda()

            # Run the model on the batch
            control_points, probabilities = model(im)

            # Compute the loss
            loss = loss_function(epoch, control_points, model.max_beziers+torch.zeros(batch_size, dtype=torch.long, device=control_points.device), probabilities, model.num_cp,
                                 im, distance_im, loss_im, grid, actual_covariances, probabilistic_map_generator, loss_type=loss_mode[0], distance='l2', gamma=0.9)

            # Backpropagate and take a gradient descent step
            loss.backward()
            optimizer.step()
            model.zero_grad()

            # Data collection for tensorboard
            k = int(int(40000/(batch_size*5))*batch_size + 1)
            if debug:
                cummulative_loss += loss.detach()
                if i%k == k-1:
                    writer.add_scalar("Training/loss", cummulative_loss/k, counter)
                    counter += 1
                    cummulative_loss = 0


        """
        Al completar cada época, probamos el modelo sobre el conjunto de validation. En concreto:
           - Calcularemos la loss del modelo sobre el conjunto de validación
           - Realizaremos 500 predicciones sobre imagenes del conjunto de validación. Generaremos una imagen a partir de la parametrización de la curva de bezier obtenida.
             Calcularemos las metricas IoU, chamfer_distance, y differentiable_chamfer_distance (probabilistic_map)
             asociadas a estas prediciones (comparandolas con el ground truth).
        """
        model.eval()
        with torch.no_grad():
            cummulative_loss = 0
            for j in range(0, len(im_validation)-batch_size+1, batch_size):
                # Get the batch
                im = im_validation[j:j+batch_size]#.cuda()
                if loss_mode[0] == 'pmap':
                    loss_im = loss_im_validation[j:j + batch_size]#.cuda()
                if loss_mode[0] == 'chamfer':
                    distance_im = distance_im_validation[j:j + batch_size]#.cuda()

                # Run the model on the batch
                control_points, probabilities = model(im)

                # Compute the loss
                loss = loss_function(epoch, control_points, model.max_beziers+torch.zeros(batch_size, dtype=torch.long, device=control_points.device),
                                     probabilities, model.num_cp, im, distance_im, loss_im, grid, actual_covariances, probabilistic_map_generator, loss_type=loss_mode[0], distance='l2', gamma=0.9)
                cummulative_loss += loss.detach()

            # Apply the learning rate scheduler
            scheduler.step(cummulative_loss)

            # Collect the data for tensorboard
            if debug:
                writer.add_scalar("Validation/loss", cummulative_loss/(j/batch_size+1), counter)

            # If this is the lowest loss so far, save the model weights
            if cummulative_loss < best_loss:
                print("The model improved! New loss={}".format(cummulative_loss/(j/batch_size+1)))
                best_loss = cummulative_loss
                torch.save(model.state_dict(), basedir+"/state_dicts/ProbabilisticBezierEncoder/MultiBezierModels/FixedCP/"+str(model.num_cp)+"CP_maxBeziers"+str(model.max_beziers)+"loss"+str(loss_mode[0]))
            cummulative_loss = 0


            # Graphical rendering of the forward mode
            target_images = im_validation[0:200:20].cuda()
            forwarded_images = torch.zeros_like(target_images)
            forwarded_cp, _ = model(target_images)

            # Render the forward images
            for i in range(model.max_beziers):
                num_cps = model.num_cp * torch.ones(10, dtype=torch.long, device=forwarded_cp.device)
                im_seq = bezier(forwarded_cp[model.num_cp * i: model.num_cp * (i + 1)], num_cps,
                                torch.linspace(0, 1, 150, device=control_points.device).unsqueeze(0), device='cuda')
                im_seq = torch.round(im_seq).long()
                for j in range(10):
                    forwarded_images[j, 0, im_seq[j, :, 0], im_seq[j, :, 1]] = 1

            img_grid = torchvision.utils.make_grid(forwarded_images)
            writer.add_image('forwarded_images', img_grid)

            
            # Start the evaluation of "prediction" mode
            if epoch > 60:
                iou_value = 0
                chamfer_value = 0

                # First, predict 10 images that will be stored in tensorboard
                target_images = im_validation[0:200:20].cuda()
                predicted_images = torch.zeros_like(target_images)
                control_points, num_beziers = model.predict(target_images)

                # Render the predicted images
                i = 0
                not_finished = num_beziers > i
                to_end = torch.sum(not_finished)
                while to_end:
                    num_cps = model.num_cp * torch.ones_like(num_beziers[not_finished])
                    im_seq = bezier(control_points[model.num_cp*i: model.num_cp*(i+1), not_finished], num_cps, torch.linspace(0, 1, 150, device=control_points.device).unsqueeze(0), device='cuda')
                    im_seq = torch.round(im_seq).long()
                    k = 0
                    for j in range(10):
                        if not_finished[j]:
                            predicted_images[j, 0, im_seq[k, :, 0], im_seq[k, :, 1]] = 1
                            k += 1
                    i += 1
                    not_finished = num_beziers > i
                    to_end = torch.sum(not_finished)

                # Store these first 10 images in tensorboard
                img_grid = torchvision.utils.make_grid(target_images)
                writer.add_image('target_images', img_grid)
                img_grid = torchvision.utils.make_grid(predicted_images)
                writer.add_image('predicted_images', img_grid)

                # Compute metrics
                iou_value += intersection_over_union(predicted_images, target_images)
                chamfer_value += np.sum(
                    chamfer_distance(predicted_images[:, 0].cpu().numpy(), target_images[:, 0].cpu().numpy()))


                # Finally, predict 490 more images to compute IoU and chamfer_distance
                idxs = [200, 1600, 3000, 4400, 5800, 7200, 8600, 10000]
                for i in range(7):
                    target_images = im_validation[idxs[i]:idxs[i+1]:20].cuda()
                    predicted_images = torch.zeros_like(target_images)
                    control_points, num_beziers = model.predict(target_images)

                    # Render the predicted images
                    i = 0
                    not_finished = num_beziers > i
                    to_end = torch.sum(not_finished)
                    while to_end:
                        num_cps = model.num_cp * torch.ones_like(num_beziers[not_finished])
                        im_seq = bezier(control_points[model.num_cp * i: model.num_cp * (i + 1), not_finished], num_cps,
                                    torch.linspace(0, 1, 150, device=control_points.device).unsqueeze(0), device='cuda')
                        im_seq = torch.round(im_seq).long()
                        k = 0
                        for j in range(70):
                            if not_finished[j]:
                                predicted_images[j, 0, im_seq[k, :, 0], im_seq[k, :, 1]] = 1
                                k += 1
                        i += 1
                        not_finished = num_beziers > i
                        to_end = torch.sum(not_finished)

                    # Compute metrics
                    iou_value += intersection_over_union(predicted_images, target_images)
                    chamfer_value += np.sum(
                        chamfer_distance(predicted_images[:, 0].cpu().numpy(), target_images[:, 0].cpu().numpy()))

                # Store the results in tensorboard
                writer.add_scalar("Prediction/IoU", iou_value / 500, counter)
                writer.add_scalar("Prediction/Chamfer_distance", chamfer_value / 500, counter)

        # Back to train mode for the next epoch
        model.train()
        print("Epoch time:", time.time()-t0)
Example #6
def train_one_bezier_transformer(
        model,
        dataset,
        batch_size,
        num_epochs,
        optimizer,
        num_experiment,
        lr=1e-4,
        loss_type='probabilistic',
        dataset_name="MNIST",
        cuda=True,
        debug=True):  # rep_coef=0.1, dist_thresh=4.5, second_term=True
    # torch.autograd.set_detect_anomaly(True)
    print("\n\nTHE TRAINING BEGINS")
    print(
        "MultiBezier Experiment #{} ---> num_cp={} max_beziers={} batch_size={} num_epochs={} learning_rate={}"
        .format(num_experiment, model.num_cp, model.max_beziers, batch_size,
                num_epochs, lr))

    # basedir = "/data1slow/users/asuso/trans_bezier"
    basedir = "/home/asuso/PycharmProjects/trans_bezier"

    # Initialize a variable that will store the best loss obtained on validation
    best_loss = float('inf')

    # Initialize the tensorboard writer and the variables we will use for
    # data collection.
    if debug:
        # Tensorboard writer
        writer = SummaryWriter(
            basedir +
            "/graphics/ProbabilisticBezierEncoder/MultiBezierModels/ParallelVersion/"
            + str(dataset_name) + "_" + str(loss_type) + "_" +
            str(model.num_cp) + "CP_maxBeziers" + str(model.max_beziers))
        counter = 0

    # Get the dataset images
    images = dataset

    if loss_type == "probabilistic":
        # Initialize the probabilistic map generator and the covariance matrices
        probabilistic_map_generator = ProbabilisticMap(
            (model.image_size, model.image_size, 50))
        cp_covariance = torch.tensor([[[1, 0], [0, 1]]
                                      for i in range(model.num_cp)],
                                     dtype=torch.float32)
        covariances = torch.empty((model.num_cp, batch_size, 2, 2))
        for i in range(batch_size):
            covariances[:, i, :, :] = cp_covariance
        # Get the distance_images
        distance_images = generate_distance_images(images)

        if cuda:
            distance_images = distance_images.cuda()
            probabilistic_map_generator = probabilistic_map_generator.cuda()
            covariances = covariances.cuda()

        distance_im_training = distance_images[:int(0.83 * images.shape[0])]
        distance_im_validation = distance_images[int(0.83 * images.shape[0]):]

    else:
        # Get the dataset's groundtruth_seq (bs, max_N, 2)
        im_size = images.shape[-1]
        # Pad value far outside the canvas, as the new_loss docstring requires,
        # so padded entries never win the min in the chamfer computation
        groundtruth_sequences = torch.full(
            (images.shape[0], im_size * im_size, 2), 10 * im_size, dtype=torch.long)
        max_coords = 0
        for i, im in enumerate(images):
            new_coords = torch.nonzero(im[0])
            groundtruth_sequences[i, :new_coords.shape[0]] = new_coords
            if new_coords.shape[0] > max_coords:
                max_coords = new_coords.shape[0]
        groundtruth_sequences = groundtruth_sequences[:, :max_coords]

        if cuda:
            groundtruth_sequences = groundtruth_sequences.cuda()

        groundtruth_seq_training = groundtruth_sequences[:int(0.83 * images.shape[0])]
        groundtruth_seq_validation = groundtruth_sequences[int(0.83 * images.shape[0]):]

    # Build the grid
    grid = torch.empty((1, 1, images.shape[2], images.shape[2], 2),
                       dtype=torch.float32)
    for i in range(images.shape[2]):
        grid[0, 0, i, :, 0] = i
        grid[0, 0, :, i, 1] = i

    if cuda:
        images = images.cuda()
        grid = grid.cuda()
        model = model.cuda()

    # Split the dataset into training and validation
    # images.shape=(N, 1, 64, 64)
    im_training = images[:int(0.83 * images.shape[0])]
    im_validation = images[int(0.83 * images.shape[0]):]

    # Define the optimizer
    optimizer = optimizer(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=10**(-0.5),
                                  patience=8,
                                  min_lr=1e-8)

    for epoch in range(num_epochs):
        t0 = time.time()
        print("Beginning epoch number", epoch + 1)
        cummulative_loss = 0
        for i in range(0, len(im_training) - batch_size + 1, batch_size):
            # Get the batch
            im = im_training[i:i + batch_size]  #.cuda()

            # Run the model on the batch
            control_points = model(im)

            if loss_type == "probabilistic":
                distance_im = distance_im_training[i:i + batch_size]  # .cuda()
                # Compute the loss
                loss = loss_function(
                    control_points, im, distance_im, covariances,
                    probabilistic_map_generator, grid
                )  #repulsion_coef=rep_coef, dist_thresh=dist_thresh, second_term=second_term
            else:
                groundtruth_seq = groundtruth_seq_training[i:i + batch_size]
                loss = new_loss(control_points, im, groundtruth_seq, grid)

            # Backpropagate and take a gradient descent step
            loss.backward()
            optimizer.step()
            model.zero_grad()

            if debug:
                cummulative_loss += loss.detach()

        if debug:
            writer.add_scalar("Training/loss",
                              cummulative_loss / (i / batch_size), counter)
        """
        Al completar cada época, probamos el modelo sobre el conjunto de validation. En concreto:
           - Calcularemos la loss del modelo sobre el conjunto de validación
           - Realizaremos 500 predicciones sobre imagenes del conjunto de validación. Generaremos una imagen a partir de la parametrización de la curva de bezier obtenida.
             Calcularemos las metricas IoU, chamfer_distance, y differentiable_chamfer_distance (probabilistic_map)
             asociadas a estas prediciones (comparandolas con el ground truth).
        """
        model.eval()
        with torch.no_grad():
            cummulative_loss = 0
            for j in range(0, len(im_validation) - batch_size + 1, batch_size):
                # Get the batch
                im = im_validation[j:j + batch_size]  # .cuda()

                # Run the model on the batch
                control_points = model(im)

                if loss_type == "probabilistic":
                    distance_im = distance_im_validation[j:j + batch_size]  # .cuda()
                    # Compute the loss
                    loss = loss_function(
                        control_points, im, distance_im, covariances,
                        probabilistic_map_generator, grid
                    )  # repulsion_coef=rep_coef, dist_thresh=dist_thresh, second_term=second_term
                else:
                    groundtruth_seq = groundtruth_seq_validation[j:j + batch_size]
                    loss = new_loss(control_points, im, groundtruth_seq, grid)
                cummulative_loss += loss.detach()

            # Apply the learning rate scheduler
            scheduler.step(cummulative_loss)

            # Collect the data for tensorboard
            if debug:
                writer.add_scalar("Validation/loss", cummulative_loss / j,
                                  counter)

            # If this is the lowest loss so far, save the model weights
            if cummulative_loss < best_loss:
                print("The model improved! New loss={}".format(
                    cummulative_loss / j))
                best_loss = cummulative_loss
                torch.save(
                    model.state_dict(), basedir +
                    "/state_dicts/ProbabilisticBezierEncoder/MultiBezierModels/ParallelVersion/"
                    + str(dataset_name) + "_" + str(loss_type) + "_" +
                    str(model.num_cp) + "CP_maxBeziers" +
                    str(model.max_beziers))

            # Start the evaluation of "prediction" mode
            iou_value = 0
            chamfer_value = 0

            # First, predict 10 images that will be stored in tensorboard
            target_images = im_validation[0:200:20]  #.cuda()
            predicted_images = torch.zeros_like(target_images)
            control_points = model(target_images)

            # Render the predicted images
            for bezier_cp in control_points:
                # Compute this curve's point sequence
                num_cps = model.num_cp * torch.ones(
                    10, dtype=torch.long, device=bezier_cp.device)
                im_seq = bezier(
                    bezier_cp,
                    num_cps,
                    torch.linspace(0, 1, 150,
                                   device=num_cps.device).unsqueeze(0),
                    device=num_cps.device)
                im_seq = torch.round(im_seq).long()
                for j in range(10):
                    predicted_images[j, 0, im_seq[j, :, 0], im_seq[j, :, 1]] = 1

            # Store these first 10 images in tensorboard
            img_grid = torchvision.utils.make_grid(target_images)
            writer.add_image('target_images', img_grid)
            img_grid = torchvision.utils.make_grid(predicted_images)
            writer.add_image('predicted_images', img_grid)

            # Compute metrics
            iou_value += intersection_over_union(predicted_images,
                                                 target_images)
            chamfer_value += np.sum(
                chamfer_distance(predicted_images[:, 0].cpu().numpy(),
                                 target_images[:, 0].cpu().numpy()))

            # Finally, predict 490 more images to compute IoU and chamfer_distance
            idxs = [200, 1600, 3000, 4400, 5800, 7200, 8600, 10000]
            for i in range(7):
                target_images = im_validation[idxs[i]:idxs[i + 1]:20]  #.cuda()
                predicted_images = torch.zeros_like(target_images)
                control_points = model(target_images)

                # Render the predicted images
                for bezier_cp in control_points:
                    # Compute this curve's point sequence
                    num_cps = model.num_cp * torch.ones(
                        70, dtype=torch.long, device=bezier_cp.device)
                    im_seq = bezier(
                        bezier_cp,
                        num_cps,
                        torch.linspace(0, 1, 150,
                                       device=num_cps.device).unsqueeze(0),
                        device=num_cps.device)
                    im_seq = torch.round(im_seq).long()
                    for j in range(70):
                        predicted_images[j, 0, im_seq[j, :, 0], im_seq[j, :, 1]] = 1

                # Compute metrics
                iou_value += intersection_over_union(predicted_images,
                                                     target_images)
                chamfer_value += np.sum(
                    chamfer_distance(predicted_images[:, 0].cpu().numpy(),
                                     target_images[:, 0].cpu().numpy()))

            # Store the results in tensorboard
            writer.add_scalar("Prediction/IoU", iou_value / 500, counter)
            writer.add_scalar("Prediction/Chamfer_distance",
                              chamfer_value / 500, counter)

        # Back to train mode for the next epoch
        model.train()
        print("Epoch time:", time.time() - t0)