예제 #1
0
def train():
    """Train a ``networks.FixedSourceDeformer`` configured by the global ``opt``.

    The dataset supplies a fixed source mesh and a cage; the network is built
    from them and its ``encoder`` and ``nd_decoder`` are optimized with Adam.
    ``opt.mvc_weight`` is forced to 0 at the start.  The first time
    ``progress >= opt.warmup_epochs`` *and* ``opt.deform_template`` or
    ``opt.optimize_template`` is set, the ``nc_decoder`` and/or
    ``template_vertices`` are added to the optimizer at ``0.1 * opt.lr``, the
    original ``opt.mvc_weight`` is restored and a "warmed_up" checkpoint is
    saved.

    Side effects:
        * creates ``opt.log_dir`` and copies this script plus a few sibling
          source files into it;
        * writes the initial source/template meshes as ``.ply`` files;
        * appends log lines to ``<log_dir>/loss_log.txt``;
        * saves checkpoints ("warmed_up", "latest" every 500 steps, one per
          ``save_interval`` epochs, and "final");
        * mutates ``opt.mvc_weight``;
        * calls ``test_all(net=net)`` when training finishes.
    """
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=0,
        # Seed NumPy in each worker from the current state offset by the
        # worker index.
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))
    source_shape = dataset.mesh_vertex.unsqueeze(0).to(dtype=torch.float)
    source_face = dataset.mesh_face.unsqueeze(0)
    cage_shape = dataset.cage_vertex.unsqueeze(0).to(dtype=torch.float)
    cage_face = dataset.cage_face.unsqueeze(0)
    mesh = Mesh(vertices=cage_shape[0], faces=cage_face[0])
    build_gemm(mesh, cage_face[0])
    cage_edge_points = torch.from_numpy(get_edge_points(mesh)).cuda()
    cage_edges = edge_vertex_indices(cage_face[0])

    # network
    net = networks.FixedSourceDeformer(
        opt,
        3,
        opt.num_point,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_shape.transpose(1, 2),
        template_faces=cage_face,
        source_vertices=source_shape.transpose(1, 2),
        source_faces=source_face).cuda()
    print(net)
    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)
    net.train()

    all_losses = losses.AllLosses(opt)

    # optimizer: only the encoder and nd_decoder are trained during warm-up
    optimizer = torch.optim.Adam([{
        'params': net.nd_decoder.parameters()
    }, {
        "params": net.encoder.parameters()
    }],
                                 lr=opt.lr)

    # Snapshot the code and the initial meshes next to the logs.
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    script_dir = os.path.dirname(__file__)
    for source_name in ("network2.py", "common.py", "losses.py",
                        "datasets.py"):
        shutil.copy2(os.path.join(script_dir, source_name), opt.log_dir)
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_Sa.ply".format(0)),
        net.source_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.source_faces[0].detach().cpu())
    pymesh.save_mesh_raw(
        os.path.join(opt.log_dir, "t{:06d}_template.ply".format(0)),
        net.template_vertices[0].transpose(0, 1).detach().cpu().numpy(),
        net.template_faces[0].detach().cpu())

    # Halve the learning rate once, after 75% of the epochs (at least 1).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                max(int(opt.nepochs * 0.75),
                                                    1),
                                                gamma=0.5,
                                                last_epoch=-1)

    # train
    t = 0  # global step counter; not advanced for skipped batches
    start_epoch = 0
    warmed_up = False
    # Hold the MVC weight at 0 until the warm-up switch below restores it.
    mvc_weight = opt.mvc_weight
    opt.mvc_weight = 0

    # Negative value is the "not initialised yet" sentinel.
    running_avg_loss = -1
    log_interval = min(max(len(dataloader) // 5, 50), 200)
    save_interval = max(opt.nepochs // 10, 1)

    if opt.epoch:
        # Resume: continue epoch and step counters from the requested epoch.
        start_epoch = opt.epoch % opt.nepochs
        t += start_epoch * len(dataloader)

    # The context manager guarantees the log file is closed even if training
    # raises (detect_anomaly turns NaN gradients into exceptions).
    with open(os.path.join(opt.log_dir, "loss_log.txt"), "a") as log_file, \
            torch.autograd.detect_anomaly():
        for epoch in range(start_epoch, opt.nepochs):
            for epoch_t, data in enumerate(dataloader):
                # Fractional epoch counter.
                progress = epoch_t / len(dataloader) + epoch
                if (opt.deform_template or opt.optimize_template) and (
                        progress >= opt.warmup_epochs) and (not warmed_up):
                    if opt.deform_template:
                        optimizer.add_param_group({
                            'params':
                            net.nc_decoder.parameters(),
                            'lr':
                            0.1 * opt.lr
                        })
                    if opt.optimize_template:
                        optimizer.add_param_group({
                            'params': net.template_vertices,
                            'lr': 0.1 * opt.lr
                        })
                    warmed_up = True
                    # start to compute mvc weight
                    opt.mvc_weight = mvc_weight
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="warmed_up")

                ############# get data ###########
                data = dataset.uncollate(data)
                data["cage_edge_points"] = cage_edge_points
                data["cage_edges"] = cage_edges
                data["source_shape"] = net.source_vertices.detach()
                data["source_face"] = net.source_faces.detach()

                ############# run network ###########
                optimizer.zero_grad()
                target_shape_t = data["target_shape"].transpose(1, 2)
                sample_idx = None

                if "sample_idx" in data:
                    sample_idx = data["sample_idx"]
                    if data["source_normals"] is not None:
                        # Keep only the normals of the sampled points.
                        data["source_normals"] = torch.gather(
                            data["source_normals"], 1,
                            sample_idx.unsqueeze(-1).expand(-1, -1, 3))

                outputs = net(target_shape_t, sample_idx)
                # The original tested the string literal ("source_mesh" is
                # not None), which is always true; test the stored value.
                has_source_mesh = ("source_mesh" in data
                                   and data["source_mesh"] is not None)
                if opt.sfnormal_weight > 0 and has_source_mesh:
                    if outputs["deformed"].shape[1] == data[
                            "source_mesh"].shape[1]:
                        outputs["deformed_hr"] = outputs["deformed"]
                    else:
                        # Sizes differ along dim 1: deform the full source
                        # mesh with the predicted cage instead.
                        batch = data["source_mesh"].shape[0]
                        outputs["deformed_hr"] = deform_with_MVC(
                            outputs["cage"].expand(batch, -1, -1).detach(),
                            outputs["new_cage"],
                            outputs["cage_face"].expand(batch, -1, -1),
                            data["source_mesh"])
                data["source_shape"] = outputs["source_shape"]

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))
                # Track the average as a plain float so it does not keep an
                # autograd graph alive across iterations.
                loss_value = loss_sum.item()
                if running_avg_loss < 0:
                    running_avg_loss = loss_value
                else:
                    running_avg_loss += (loss_value -
                                         running_avg_loss) / (t + 1)

                if (t % log_interval
                        == 0) or (loss_value > 10 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        not warmed_up, epoch, t, ", ".join([
                            "{} {:.3g}".format(k,
                                               v.mean().item())
                            for k, v in current_loss.items()
                        ]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                if loss_value > 100 * running_avg_loss:
                    # Outlier batch: drop it without touching the weights.
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_value, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()

                if opt.alternate_cd:
                    # Alternate which part of the network gets gradients.
                    optimize_C = (progress > opt.warmup_epochs) and (
                        t % (opt.c_step + opt.d_step)) > opt.d_step
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                        net.encoder.zero_grad()
                    else:
                        try:
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            # No nc_decoder on this network: clear the
                            # template gradient instead, if there is one.
                            if net.template_vertices.grad is not None:
                                net.template_vertices.grad.zero_()

                # clamp_gradient_norm(net, 1)
                optimizer.step()
                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()

    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test_all(net=net)
예제 #2
0
def train():
    """Train a ``networks.NetworkFull`` configured by the global ``opt``.

    For ``opt.dim == 3`` the initial cage is loaded from ``opt.template``;
    otherwise a polygon cage with ``opt.cage_deg`` vertices is generated.
    The ``encoder``, ``nd_decoder`` and ``merger`` are optimized with Adam;
    ``nc_decoder`` (``opt.full_net``) and ``template_vertices``
    (``opt.optimize_template``) are added as extra parameter groups.
    While ``epoch < opt.warmup_epochs`` the gradients of ``nc_decoder`` and
    ``encoder`` are zeroed before every optimizer step.

    Side effects:
        * creates ``opt.log_dir`` and copies this script plus a few sibling
          source files into it;
        * appends the network description and log lines to
          ``<log_dir>/training_log.txt``;
        * saves checkpoints ("latest" every 500 steps, one per
          ``save_interval`` epochs, and "final");
        * runs ``test`` after every epoch when ``opt.eval`` is set (failures
          are logged, not raised) and once more after training.
    """
    dataset = build_dataset(opt)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=tolerating_collate,
        num_workers=2,
        # Seed NumPy in each worker from the current state offset by the
        # worker index.
        worker_init_fn=lambda worker_id: np.random.seed(
            np.random.get_state()[1][0] + worker_id))

    if opt.dim == 3:
        # cage (1,N,3)
        init_cage_V, init_cage_Fs = loadInitCage([opt.template])
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        cage_edge_points_list = []
        cage_edges_list = []
        for F in init_cage_Fs:
            mesh = Mesh(vertices=init_cage_V[0], faces=F[0])
            build_gemm(mesh, F[0])
            cage_edge_points_list.append(
                torch.from_numpy(get_edge_points(mesh)).cuda())
            # Append (the original re-assigned a one-element list) so both
            # lists stay index-aligned; only the last entry is read below.
            cage_edges_list.append(edge_vertex_indices(F[0]))
    else:
        init_cage_V = generatePolygon(0, 0, 1.5, 0, 0, 0, opt.cage_deg)
        init_cage_V = torch.tensor([(x, y) for x, y in init_cage_V],
                                   dtype=torch.float).unsqueeze(0)
        cage_V_t = init_cage_V.transpose(1, 2).detach().cuda()
        init_cage_Fs = [
            torch.arange(opt.cage_deg, dtype=torch.int64).view(1, 1,
                                                               -1).cuda()
        ]

    # network
    net = networks.NetworkFull(
        opt,
        dim=opt.dim,
        bottleneck_size=opt.bottleneck_size,
        template_vertices=cage_V_t,
        template_faces=init_cage_Fs[-1],
    ).cuda()

    net.apply(weights_init)
    if opt.ckpt:
        load_network(net, opt.ckpt)

    all_losses = losses.AllLosses(opt)
    # optimizer
    optimizer = torch.optim.Adam([{
        "params": net.encoder.parameters()
    }, {
        "params": net.nd_decoder.parameters()
    }, {
        "params": net.merger.parameters()
    }],
                                 lr=opt.lr)

    if opt.full_net:
        optimizer.add_param_group({
            'params': net.nc_decoder.parameters(),
            'lr': 0.1 * opt.lr
        })
    if opt.optimize_template:
        optimizer.add_param_group({
            'params': net.template_vertices,
            'lr': opt.lr
        })

    # Decay the learning rate by 10x every 40% of the epochs.  Clamp the
    # step size to >= 1: int(nepochs * 0.4) is 0 for nepochs < 3, which
    # StepLR cannot handle.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                max(int(opt.nepochs * 0.4),
                                                    1),
                                                gamma=0.1,
                                                last_epoch=-1)

    # train
    net.train()
    start_epoch = 0
    t = 0  # global step counter; not advanced for skipped batches

    # Snapshot the code next to the logs.
    os.makedirs(opt.log_dir, exist_ok=True)
    shutil.copy2(__file__, opt.log_dir)
    script_dir = os.path.dirname(__file__)
    for source_name in ("networks.py", "losses.py", "datasets.py",
                        "common.py", "option.py"):
        shutil.copy2(os.path.join(script_dir, source_name), opt.log_dir)
    print(net)

    log_interval = max(len(dataloader) // 5, 50)
    save_interval = max(opt.nepochs // 10, 1)
    # Negative value is the "not initialised yet" sentinel.
    running_avg_loss = -1

    if opt.epoch:
        # Resume: continue epoch and step counters from the requested epoch.
        start_epoch = opt.epoch % opt.nepochs
        t += start_epoch * len(dataloader)

    # The context manager guarantees the log file is closed even if training
    # raises (detect_anomaly turns NaN gradients into exceptions).
    with open(os.path.join(opt.log_dir, "training_log.txt"),
              "a") as log_file, torch.autograd.detect_anomaly():
        log_file.write(str(net) + "\n")

        for epoch in range(start_epoch, opt.nepochs):
            for t_epoch, data in enumerate(dataloader):
                warming_up = epoch < opt.warmup_epochs
                # Fractional epoch counter.
                progress = t_epoch / len(dataloader) + epoch

                ############# get data ###########
                data = dataset.uncollate(data)
                data = crisscross_input(data)
                if opt.dim == 3:
                    data["cage_edge_points"] = cage_edge_points_list[-1]
                    data["cage_edges"] = cage_edges_list[-1]
                source_shape, target_shape = data["source_shape"], data[
                    "target_shape"]

                ############# blending ############
                if opt.blend_style:
                    # One random blend factor in [0, 1) per batch element.
                    blend_alpha = torch.rand(
                        (source_shape.shape[0], 1),
                        dtype=torch.float32).to(device=source_shape.device)
                else:
                    blend_alpha = 1.0
                data["alpha"] = blend_alpha

                ############# run network ###########
                optimizer.zero_grad()
                source_shape_t = source_shape.transpose(1, 2)
                target_shape_t = target_shape.transpose(1, 2)
                outputs = net(source_shape_t, target_shape_t, data["alpha"])

                ############# get losses ###########
                current_loss = all_losses(data, outputs, progress)
                loss_sum = torch.sum(
                    torch.stack([v for v in current_loss.values()], dim=0))
                # Track the average as a plain float so it does not keep an
                # autograd graph alive across iterations.
                loss_value = loss_sum.item()
                if running_avg_loss < 0:
                    running_avg_loss = loss_value
                else:
                    running_avg_loss += (loss_value -
                                         running_avg_loss) / (t + 1)

                if (t % log_interval
                        == 0) or (loss_value > 5 * running_avg_loss):
                    log_str = "warming up {} e {:03d} t {:05d}: {}".format(
                        warming_up, epoch, t, ", ".join([
                            "{} {:.3g}".format(k,
                                               v.mean().item())
                            for k, v in current_loss.items()
                        ]))
                    print(log_str)
                    log_file.write(log_str + "\n")
                    log_outputs(opt, t, outputs, data)

                if loss_value > 100 * running_avg_loss:
                    # Outlier batch: drop it without touching the weights.
                    logger.info(
                        "loss ({}) > 100*running_average_loss ({}). Skip without update."
                        .format(loss_value, 100 * running_avg_loss))
                    torch.cuda.empty_cache()
                    continue

                loss_sum.backward()
                if warming_up:
                    # During warm-up, discard nc_decoder/encoder gradients.
                    try:
                        net.nc_decoder.zero_grad()
                        net.encoder.zero_grad()
                    except AttributeError:
                        # Fall back to clearing the template gradient, if
                        # one has been computed.
                        if net.template_vertices.grad is not None:
                            net.template_vertices.grad.zero_()

                if opt.alternate_cd:
                    # Alternate, per epoch, which part gets gradients.
                    optimize_C = (epoch > opt.warmup_epochs) and (
                        epoch % (opt.c_epoch + opt.d_epoch)) > opt.d_epoch
                    if optimize_C:
                        net.nd_decoder.zero_grad()
                    else:
                        try:
                            net.encoder.zero_grad()
                            net.nc_decoder.zero_grad()
                        except AttributeError:
                            if net.template_vertices.grad is not None:
                                net.template_vertices.grad.zero_()

                clamp_gradient(net, 0.1)
                optimizer.step()

                if (t + 1) % 500 == 0:
                    save_network(net,
                                 opt.log_dir,
                                 network_label="net",
                                 epoch_label="latest")

                t += 1

            if (epoch + 1) % save_interval == 0:
                save_network(net,
                             opt.log_dir,
                             network_label="net",
                             epoch_label=epoch)

            scheduler.step()
            if opt.eval:
                # Evaluation is best-effort: report the failure and go on.
                try:
                    test(net=net, save_subdir="epoch_{}".format(epoch))
                except Exception as e:
                    traceback.print_exc(file=sys.stdout)
                    # logger.warn is deprecated, and the original passed the
                    # error as a format argument without a placeholder.
                    logger.warning("Failed to run test: %s", e)

    save_network(net, opt.log_dir, network_label="net", epoch_label="final")
    test(net=net)