import os
import time

import torch
import torch.nn as nn

# fileImport, load_dataset_end2end, get_input, to_var, MLP, and Encoder are
# assumed to come from this project's own modules (not shown here).


def main(args):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"    

    importer = fileImport()
    env_data_path = args.env_data_path
    path_data_path = args.path_data_path
    pcd_data_path = args.pointcloud_data_path

    envs = importer.environments_import(env_data_path + args.envs_file)

    print("Loading obstacle data...\n")
    dataset_train, targets_train, pc_inds_train, obstacles = load_dataset_end2end(
        envs, path_data_path, pcd_data_path, args.path_data_file, importer, NP=1000)

    print("Loaded dataset, targets, and pontcloud obstacle vectors: ")
    print(str(len(dataset_train)) + " " +
        str(len(targets_train)) + " " + str(len(pc_inds_train)))
    print("\n")

    if not os.path.exists(args.trained_model_path):
        os.makedirs(args.trained_model_path)

    # Build the models
    mlp = MLP(args.mlp_input_size, args.mlp_output_size)
    encoder = Encoder(args.enc_input_size, args.enc_output_size)

    if torch.cuda.is_available():
        encoder.cuda()
        mlp.cuda()


    # Loss and Optimizer
    criterion = nn.MSELoss()
    params = list(encoder.parameters())+list(mlp.parameters())
    optimizer = torch.optim.Adagrad(params, lr=args.learning_rate)
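    # The encoder and planner MLP are trained jointly (end to end): a single
    # optimizer updates both parameter sets from the same loss.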

    total_loss = []

    sm = 90  # epoch at which checkpoint saving starts

    print("Starting epochs...\n")
    for epoch in range(args.num_epochs):
        start = time.time()
        print("epoch " + str(epoch))
        avg_loss = 0
        for i in range(0, len(dataset_train), args.batch_size):
            # Forward, Backward and Optimize
            # zero gradients
            encoder.zero_grad()
            mlp.zero_grad()

            # convert to pytorch tensors and Variables
            bi, bt, bobs = get_input(
                i, dataset_train, targets_train, pc_inds_train, obstacles, args.batch_size)
            bi = to_var(bi)
            bt = to_var(bt)
            bobs = to_var(bobs)

            # forward pass through encoder
            h = encoder(bobs)

            # concatenate encoder output with dataset input
            inp = torch.cat((bi, h), dim=1)

            # forward pass through mlp
            bo = mlp(inp)

            # compute overall loss and backprop all the way
            loss = criterion(bo, bt)
            avg_loss = avg_loss + loss.item()
            loss.backward()
            optimizer.step()

        print("--average loss:")
        print(avg_loss/(len(dataset_train)/args.batch_size))
        total_loss.append(avg_loss/(len(dataset_train)/args.batch_size))
        # Save the models
        if epoch == sm:
            print("\nSaving model\n")
            print("time: " + str(time.time() - start))
            torch.save(encoder.state_dict(), os.path.join(
                args.trained_model_path, 'cae_encoder_'+str(epoch)+'.pkl'))
            torch.save(total_loss, 'total_loss_'+str(epoch)+'.dat')

            model_path = 'mlp_PReLU_ae_dd'+str(epoch)+'.pkl'
            torch.save(mlp.state_dict(), os.path.join(
                args.trained_model_path, model_path))
            sm += 10  # after the first save, checkpoint every 10 epochs

    torch.save(total_loss, 'total_loss.dat')
    model_path = 'mlp_PReLU_ae_dd_final.pkl'
    torch.save(mlp.state_dict(), os.path.join(args.trained_model_path, model_path))
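

# Hypothetical entry point, sketched from the argument names that main() reads
# above; the paths and defaults below are illustrative placeholders, not the
# original project's values.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--env_data_path', default='./data/envs/')
    parser.add_argument('--path_data_path', default='./data/paths/')
    parser.add_argument('--pointcloud_data_path', default='./data/pcd/')
    parser.add_argument('--envs_file', default='envs.pkl')
    parser.add_argument('--path_data_file', default='paths.pkl')
    parser.add_argument('--trained_model_path', default='./models/')
    parser.add_argument('--mlp_input_size', type=int, required=True)
    parser.add_argument('--mlp_output_size', type=int, required=True)
    parser.add_argument('--enc_input_size', type=int, required=True)
    parser.add_argument('--enc_output_size', type=int, required=True)
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=100)
    main(parser.parse_args())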

# Example #2

# NOTE: the original snippet is truncated at its start; the missing lines
# appear to have constructed a DataLoader (only the trailing fragment
# "num_workers=8)" survives).
"""
Model settings 
"""
mesh = kal.rep.TriangleMesh.from_obj('386.obj')
if args.device == "cuda":
    mesh.cuda()
initial_verts = mesh.vertices.clone()
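# Keep a copy of the template vertices; the model's predicted offsets are
# presumably added to these to deform the base mesh.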

model = Encoder(4, 5, args.batchsize, 137,
                mesh.vertices.shape[0]).to(args.device)
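# The Encoder's positional arguments are specific to this project's model
# definition; 137 is presumably the input image resolution (137x137).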

loss_fn = kal.metrics.point.chamfer_distance
loss_edge = kal.metrics.mesh.edge_length
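# chamfer_distance compares the predicted and target point sets via
# bidirectional nearest-neighbour distances; edge_length serves as a
# regularizer that penalizes long, degenerate mesh edges.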

optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Create log directory, if it doesn't already exist
args.logdir = os.path.join(args.logdir, args.expid)
if not os.path.isdir(args.logdir):
    os.makedirs(args.logdir)
    print('Created dir:', args.logdir)

# Log all commandline args
with open(os.path.join(args.logdir, 'args.txt'), 'w') as f:
    json.dump(args.__dict__, f, indent=2)


class Engine(object):
    """Engine that runs training and inference.
    Args