Ejemplo n.º 1
0
                             batch_size=1,
                             shuffle=True,
                             drop_last=True)
# Data loaders for training and for the final evaluation. Both yield batches
# of size 1 in shuffled order; drop_last=True drops an incomplete final batch,
# which can never occur with batch_size=1.
dataloader_train = DataLoader(dataset=crc_dist,
                              batch_size=1,
                              shuffle=True,
                              drop_last=True)
dataloader_final_eval = DataLoader(dataset=crc_final,
                                   batch_size=1,
                                   shuffle=True,
                                   drop_last=True)

#####################################
# The models & the optimizers
#####################################
neural_map = NeuralTransportMap(space_dim=2, layers_dim=[128, 64])
opt_tm = torch.optim.Adagrad(neural_map.parameters(), lr=1e-2)

# Interval, in iterations, between two pushes of stats to tensorboard
# (taken as-is from the command-line args).
n_stats_to_tensorboard = args.crayon_send_stats_iters
logger.info(
    f"Sending stats to tensorboard every {n_stats_to_tensorboard} iterations")

# Interval, in iterations, between two model saves, chosen so that about
# `args.n_models_saved` models are saved over `args.n_train` iterations.
# Clamped to at least 1: when n_train / n_models_saved <= 0.5 the rounded
# ratio is 0, which is not a usable interval (a modulo check such as
# `iteration % n_save` would raise ZeroDivisionError).
n_save: int = max(1, round(args.n_train / args.n_models_saved))
logger.info(
    f"Save models every {n_save} iterations, for a total of {args.n_models_saved}"
)

# initialize network to the identity
# NOTE(review): this `for` header has no indented body here -- the statement
# that follows starts at column 0, so Python raises IndentationError
# ("expected an indented block") on this line. The loop body that performs the
# identity initialization appears to be missing from this snippet; restore it
# (or remove this header) before running.
for iteration, data_dict in enumerate(dataloader_init):
# Data loaders for training and for the final evaluation. Both yield batches
# of size 1 in shuffled order; drop_last=True drops an incomplete final batch,
# which can never occur with batch_size=1.
dataloader_train = DataLoader(dataset=crc_dist,
                              batch_size=1,
                              shuffle=True,
                              drop_last=True)
dataloader_final_eval = DataLoader(dataset=crc_final,
                                   batch_size=1,
                                   shuffle=True,
                                   drop_last=True)

#####################################
# The models & the optimizers
#####################################
# Transport-plan network and its optimizer.
neural_plan = NeuralTransportPlan(space_dim=2, layers_dim=[128, 64])
opt_plan = torch.optim.Adagrad(neural_plan.parameters(), lr=1e-2)

# Transport-map network and its optimizer (same architecture and learning
# rate as the plan network above).
neural_map = NeuralTransportMap(space_dim=2, layers_dim=[128, 64])
opt_tm = torch.optim.Adagrad(neural_map.parameters(), lr=1e-2)

# Interval, in iterations, between two pushes of stats to tensorboard
# (taken as-is from the command-line args).
n_stats_to_tensorboard = args.crayon_send_stats_iters
logger.info(
    f"Sending stats to tensorboard every {n_stats_to_tensorboard} iterations")

# Interval, in iterations, between two model saves, chosen so that about
# `args.n_models_saved` models are saved over `args.n_train` iterations.
# Clamped to at least 1: when n_train / n_models_saved <= 0.5 the rounded
# ratio is 0, which is not a usable interval (a modulo check such as
# `iteration % n_save` would raise ZeroDivisionError).
n_save: int = max(1, round(args.n_train / args.n_models_saved))
logger.info(
    f"Save models every {n_save} iterations, for a total of {args.n_models_saved}"
)


def sink_iterate_find_potentials(batchSize: int, C: Tensor,