Code example #1
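All three examples assume the imports below. The project-specific helpers come from this repository's own modules, whose exact paths are not reproduced here.

import os
import time
import datetime

import numpy as np
import torch
import torch.utils.data as data
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

# Assumed to come from this repository's own modules: get_train_options,
# Logger, PUNET_Dataset, PUGAN_Dataset, Generator, Generator_recon,
# Downsampler, Discriminator, xavier_init, Loss, visualize_point_cloud.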
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    os.makedirs(log_dir, exist_ok=True)
    tb_logger = Logger(log_dir)

    train_dataset = PUNET_Dataset(h5_file_path=params["dataset_dir"],
                                  split_dir=params['train_split'])
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=train_dataset,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)
    D_model = Discriminator(params, in_channels=3)
    D_model.apply(xavier_init)
    D_model = torch.nn.DataParallel(D_model).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(),
                       lr=params["lr_D"],
                       betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(),
                       lr=params["lr_G"],
                       betas=(0.9, 0.999))

    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    n_iter = 0  # global step counter (avoids shadowing the builtin `iter`)
    for e in range(params["nepoch"]):
        for batch_id, (input_data, gt_data,
                       radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

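            # (B, N, C) -> (B, 3, N): keep only xyz and move channels first,
            # the layout the network expects.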
            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().to(device)

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            repulsion_loss = Loss_fn.get_repulsion_loss(
                output_point_cloud.permute(0, 2, 1))
            uniform_loss = Loss_fn.get_uniform_loss(
                output_point_cloud.permute(0, 2, 1))
            #print(output_point_cloud.shape,gt_data.shape)
            emd_loss = Loss_fn.get_emd_loss(
                output_point_cloud.permute(0, 2, 1), gt_data.permute(0, 2, 1))

            if params['use_gan']:
                # Discriminator update: score detached fakes and reals, sum
                # the two terms, and apply a single optimizer step.
                fake_pred = D_model(output_point_cloud.detach())
                d_loss_fake = Loss_fn.get_discriminator_loss_single(
                    fake_pred, label=False)
                real_pred = D_model(gt_data)
                d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred,
                                                                    label=True)

                d_loss = d_loss_real + d_loss_fake
                d_loss.backward()
                optimizer_D.step()

                # Generator adversarial term: no detach here, so the gradient
                # flows back into G_model.
                fake_pred = D_model(output_point_cloud)
                g_loss = Loss_fn.get_generator_loss(fake_pred)

                total_G_loss = (params['uniform_w'] * uniform_loss
                                + params['emd_w'] * emd_loss
                                + params['repulsion_w'] * repulsion_loss
                                + params['gan_w'] * g_loss)
            else:
                total_G_loss = (params['emd_w'] * emd_loss
                                + params['repulsion_w'] * repulsion_loss)

            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.param_groups[0]['lr']
            current_lr_G = optimizer_G.param_groups[0]['lr']

            tb_logger.scalar_summary('repulsion_loss', repulsion_loss.item(),
                                     n_iter)
            tb_logger.scalar_summary('uniform_loss', uniform_loss.item(),
                                     n_iter)
            tb_logger.scalar_summary('emd_loss', emd_loss.item(), n_iter)
            if params['use_gan']:
                tb_logger.scalar_summary('d_loss', d_loss.item(), n_iter)
                tb_logger.scalar_summary('g_loss', g_loss.item(), n_iter)
            tb_logger.scalar_summary('lr_D', current_lr_D, n_iter)
            tb_logger.scalar_summary('lr_G', current_lr_G, n_iter)

            elapsed = str(datetime.timedelta(seconds=round(time.time() - start_t)))
            print(f"{elapsed:0>8}, epoch: {e}, "
                  f"[{batch_id + 1}/{len(train_data_loader)}], "
                  f"total_G_loss: {total_G_loss.item()}, "
                  f"iter time: {time.time() - start_t_batch}")

            n_iter += 1

        # Step the LR schedulers once per epoch, after the optimizer updates
        # (stepping them before any optimizer.step() skips the first LR value).
        D_scheduler.step()
        G_scheduler.step()
        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'],
                                          params['exp_name'])
            os.makedirs(model_save_dir, exist_ok=True)
            D_ckpt_model_filename = "D_iter_%d.pth" % e
            G_ckpt_model_filename = "G_iter_%d.pth" % e
            D_model_save_path = os.path.join(model_save_dir,
                                             D_ckpt_model_filename)
            G_model_save_path = os.path.join(model_save_dir,
                                             G_ckpt_model_filename)
            torch.save(D_model.module.state_dict(), D_model_save_path)
            torch.save(G_model.module.state_dict(), G_model_save_path)
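
For reference, a minimal driver for train(); the flag names mirror the attributes the function reads (args.exp_name, args.batch_size, args.use_gan, args.debug), but the defaults here are illustrative, not the repository's actual ones.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default='pu_experiment')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--use_gan', action='store_true')
    parser.add_argument('--debug', action='store_true')
    train(parser.parse_args())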
Code example #2
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 1024
    params["batch_size"] = args.batch_size
    params['use_gan'] = args.use_gan

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    os.makedirs(log_dir, exist_ok=True)
    tb_logger = Logger(log_dir)

    trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"])
    # print(params["dataset_dir"])
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=trainloader,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator_recon(params)
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)
    # The discriminator is set up here but never used below: this variant only
    # trains the generator with a reconstruction (EMD) loss.
    D_model = torch.nn.DataParallel(Discriminator(params,
                                                  in_channels=3)).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(),
                       lr=params["lr_D"],
                       betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(),
                       lr=params["lr_G"],
                       betas=(0.9, 0.999))

    D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    n_iter = 0  # global step counter (avoids shadowing the builtin `iter`)
    for e in range(params["nepoch"]):
        for batch_id, (input_data, gt_data,
                       radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().to(device)

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

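            # Reconstruction-only objective: EMD between the generator's
            # output and its own input.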
            emd_loss = Loss_fn.get_emd_loss(
                output_point_cloud.permute(0, 2, 1),
                input_data.permute(0, 2, 1))

            total_G_loss = emd_loss
            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.param_groups[0]['lr']
            current_lr_G = optimizer_G.param_groups[0]['lr']

            tb_logger.scalar_summary('emd_loss', emd_loss.item(), n_iter)
            tb_logger.scalar_summary('lr_D', current_lr_D, n_iter)
            tb_logger.scalar_summary('lr_G', current_lr_G, n_iter)

            elapsed = str(datetime.timedelta(seconds=round(time.time() - start_t)))
            print(f"{elapsed:0>8}, epoch: {e}, "
                  f"[{batch_id + 1}/{len(train_data_loader)}], "
                  f"total_G_loss: {total_G_loss.item()}, "
                  f"iter time: {time.time() - start_t_batch}")

            if n_iter % params['model_save_interval'] == 0 and n_iter > 0:
                model_save_dir = os.path.join(params['model_save_dir'],
                                              params['exp_name'])
                os.makedirs(model_save_dir, exist_ok=True)
                D_ckpt_model_filename = "D_iter_%d.pth" % n_iter
                G_ckpt_model_filename = "G_iter_%d.pth" % n_iter
                D_model_save_path = os.path.join(model_save_dir,
                                                 D_ckpt_model_filename)
                G_model_save_path = os.path.join(model_save_dir,
                                                 G_ckpt_model_filename)
                torch.save(D_model.module.state_dict(), D_model_save_path)
                torch.save(G_model.module.state_dict(), G_model_save_path)

            if n_iter % params['model_vis_interval'] == 0 and n_iter > 0:
                # Render the first cloud of the batch (prediction, ground
                # truth, and input) and log the images to TensorBoard.
                np_pcd = output_point_cloud.permute(
                    0, 2, 1)[0].detach().cpu().numpy()
                img = (np.array(visualize_point_cloud(np_pcd)) * 255).astype(
                    np.uint8)
                tb_logger.image_summary("images", img[np.newaxis, :], n_iter)

                gt_pcd = gt_data.permute(0, 2, 1)[0].detach().cpu().numpy()
                gt_img = (np.array(visualize_point_cloud(gt_pcd)) *
                          255).astype(np.uint8)
                tb_logger.image_summary("gt", gt_img[np.newaxis, :], n_iter)

                input_pcd = input_data.permute(0, 2,
                                               1)[0].detach().cpu().numpy()
                input_img = (np.array(visualize_point_cloud(input_pcd)) *
                             255).astype(np.uint8)
                tb_logger.image_summary("input", input_img[np.newaxis, :],
                                        n_iter)
            n_iter += 1

        # Step the LR schedulers once per epoch, after the optimizer updates
        # (stepping them before any optimizer.step() skips the first LR value).
        D_scheduler.step()
        G_scheduler.step()
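
visualize_point_cloud is not shown in this excerpt; from the `* 255` conversion above it is assumed to return an RGB float array in [0, 1]. A minimal matplotlib-based sketch under that assumption (the repository's actual helper may differ):

import matplotlib

matplotlib.use('Agg')  # render off-screen
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the 3d projection)


def visualize_point_cloud(points):
    """Render an (N, 3) point cloud to an (H, W, 3) float RGB array in [0, 1]."""
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=2)
    ax.set_axis_off()
    fig.canvas.draw()
    img = np.asarray(fig.canvas.buffer_rgba())[..., :3]  # drop the alpha channel
    plt.close(fig)
    return img.astype(np.float32) / 255.0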
Code example #3
def train(args):
    start_t = time.time()
    params = get_train_options()
    params["exp_name"] = args.exp_name
    params["patch_num_point"] = 256
    params["batch_size"] = args.batch_size

    if args.debug:
        params["nepoch"] = 2
        params["model_save_interval"] = 3
        params['model_vis_interval'] = 3

    log_dir = os.path.join(params["model_save_dir"], args.exp_name)
    os.makedirs(log_dir, exist_ok=True)
    tb_logger = Logger(log_dir)

    train_dataset = PUGAN_Dataset(h5_file_path=params["dataset_dir"], npoint=256)
    num_workers = 4
    train_data_loader = data.DataLoader(dataset=train_dataset,
                                        batch_size=params["batch_size"],
                                        shuffle=True,
                                        num_workers=num_workers,
                                        pin_memory=True,
                                        drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ##########################################
    # Initialize generator and discriminator #
    ##########################################
    G_AB = Generator(params)
    G_AB.apply(xavier_init)
    G_AB = torch.nn.DataParallel(G_AB).to(device)

    G_BA = Downsampler(params)
    G_BA.apply(xavier_init)
    G_BA = torch.nn.DataParallel(G_BA).to(device)

    D_A = Discriminator(params, in_channels=3)
    D_A.apply(xavier_init)
    D_A = torch.nn.DataParallel(D_A).to(device)

    D_B = Discriminator(params, in_channels=3)
    D_B.apply(xavier_init)
    D_B = torch.nn.DataParallel(D_B).to(device)

    ############################################
    # Optimizers and learning-rate schedulers  #
    ############################################

    optimizer_D_A = Adam(D_A.parameters(),
                         lr=params["lr_D_A"],
                         betas=(0.9, 0.999))
    optimizer_D_B = Adam(D_B.parameters(),
                         lr=params["lr_D_B"],
                         betas=(0.9, 0.999))

    optimizer_G_AB = Adam(G_AB.parameters(),
                          lr=params["lr_G_AB"],
                          betas=(0.9, 0.999))
    optimizer_G_BA = Adam(G_BA.parameters(),
                          lr=params["lr_G_BA"],
                          betas=(0.9, 0.999))

    # Each scheduler wraps its own optimizer.
    D_A_scheduler = MultiStepLR(optimizer_D_A, [50, 80], gamma=0.2)
    G_AB_scheduler = MultiStepLR(optimizer_G_AB, [50, 80], gamma=0.2)
    D_B_scheduler = MultiStepLR(optimizer_D_B, [50, 80], gamma=0.2)
    G_BA_scheduler = MultiStepLR(optimizer_G_BA, [50, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    n_iter = 0  # global step counter (avoids shadowing the builtin `iter`)
    for e in range(params["nepoch"]):

        for batch_id, (input_data, gt_data,
                       radius_data) in enumerate(train_data_loader):

            G_AB.train()
            G_BA.train()
            D_A.train()
            D_B.train()

            optimizer_G_AB.zero_grad()
            optimizer_D_A.zero_grad()
            optimizer_G_BA.zero_grad()
            optimizer_D_B.zero_grad()

            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            start_t_batch = time.time()

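            # Domain A = sparse input patches, domain B = dense ground truth:
            # G_AB upsamples A -> B, G_BA (a Downsampler) maps B -> A.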
            output_point_cloud_high = G_AB(input_data)
            output_point_cloud_low = G_BA(gt_data)

            #####################################
            #               Loss                #
            #####################################
            repulsion_loss_AB = Loss_fn.get_repulsion_loss(
                output_point_cloud_high.permute(0, 2, 1))
            uniform_loss_AB = Loss_fn.get_uniform_loss(
                output_point_cloud_high.permute(0, 2, 1))
            repulsion_loss_BA = Loss_fn.get_repulsion_loss(
                output_point_cloud_low.permute(0, 2, 1))
            uniform_loss_BA = Loss_fn.get_uniform_loss(
                output_point_cloud_low.permute(0, 2, 1))
            emd_loss_AB = Loss_fn.get_emd_loss(
                output_point_cloud_high.permute(0, 2, 1),
                gt_data.permute(0, 2, 1))

            # Cycle losses
            recov_A = G_BA(output_point_cloud_high)
            ABA_repul_loss = Loss_fn.get_repulsion_loss(
                recov_A.permute(0, 2, 1))
            ABA_uniform_loss = Loss_fn.get_uniform_loss(
                recov_A.permute(0, 2, 1))

            recov_B = G_AB(output_point_cloud_low)
            BAB_repul_loss = Loss_fn.get_repulsion_loss(
                recov_B.permute(0, 2, 1))
            BAB_uniform_loss = Loss_fn.get_uniform_loss(
                recov_B.permute(0, 2, 1))
            BAB_emd_loss = Loss_fn.get_emd_loss(recov_B.permute(0, 2, 1),
                                                gt_data.permute(0, 2, 1))

            # G_AB loss: the adversarial pass must not be detached, or no
            # gradient would reach G_AB's parameters.
            fake_pred_B = D_A(output_point_cloud_high)
            g_AB_loss = Loss_fn.get_generator_loss(fake_pred_B)
            total_G_AB_loss = (params['gan_w_AB'] * g_AB_loss
                               + params['repulsion_w_AB'] * BAB_repul_loss
                               + params['uniform_w_AB'] * BAB_uniform_loss
                               + params['emd_w_AB'] * BAB_emd_loss
                               + params['uniform_w_AB'] * uniform_loss_AB
                               + params['emd_w_AB'] * emd_loss_AB
                               + params['repulsion_w_AB'] * repulsion_loss_AB)

            total_G_AB_loss.backward()
            optimizer_G_AB.step()

            # G_BA loss (the uniform terms are left out of this objective)
            fake_pred_A = D_B(output_point_cloud_low)
            g_BA_loss = Loss_fn.get_generator_loss(fake_pred_A)
            total_G_BA_loss = (params['gan_w_BA'] * g_BA_loss
                               + params['repulsion_w_BA'] * ABA_repul_loss
                               + params['repulsion_w_BA'] * repulsion_loss_BA)

            total_G_BA_loss.backward()
            optimizer_G_BA.step()

            # Discriminator A loss. Re-zero first: the generator backward pass
            # above also deposited gradients in D_A. fake_B_buffer is assumed
            # to be a module-level replay buffer (see the sketch after this
            # example); the high-resolution fakes belong to domain B.
            optimizer_D_A.zero_grad()
            fake_B_ = fake_B_buffer.push_and_pop(output_point_cloud_high)
            fake_pred_B = D_A(fake_B_.detach())
            d_A_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_B,
                                                                  label=False)

            real_pred_B = D_A(gt_data.detach())
            d_A_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_B,
                                                                  label=True)

            d_A_loss = d_A_loss_real + d_A_loss_fake
            d_A_loss.backward()
            optimizer_D_A.step()

            # Discriminator B loss (same pattern for the low-resolution domain)
            optimizer_D_B.zero_grad()
            fake_A_ = fake_A_buffer.push_and_pop(output_point_cloud_low)
            fake_pred_A = D_B(fake_A_.detach())
            d_B_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred_A,
                                                                  label=False)

            real_pred_A = D_B(input_data.detach())
            d_B_loss_real = Loss_fn.get_discriminator_loss_single(real_pred_A,
                                                                  label=True)
            d_B_loss = d_B_loss_real + d_B_loss_fake
            d_B_loss.backward()
            optimizer_D_B.step()

            # Current learning rates (for logging)
            current_lr_D_A = optimizer_D_A.param_groups[0]['lr']
            current_lr_G_AB = optimizer_G_AB.param_groups[0]['lr']
            current_lr_D_B = optimizer_D_B.param_groups[0]['lr']
            current_lr_G_BA = optimizer_G_BA.param_groups[0]['lr']

            tb_logger.scalar_summary('d_A_loss', d_A_loss.item(), n_iter)
            tb_logger.scalar_summary('g_AB_loss', g_AB_loss.item(), n_iter)
            tb_logger.scalar_summary('Total_G_AB_loss', total_G_AB_loss.item(),
                                     n_iter)
            tb_logger.scalar_summary('lr_D_A', current_lr_D_A, n_iter)
            tb_logger.scalar_summary('lr_G_AB', current_lr_G_AB, n_iter)
            tb_logger.scalar_summary('d_B_loss', d_B_loss.item(), n_iter)
            tb_logger.scalar_summary('g_BA_loss', g_BA_loss.item(), n_iter)
            tb_logger.scalar_summary('Total_G_BA_loss', total_G_BA_loss.item(),
                                     n_iter)
            tb_logger.scalar_summary('lr_D_B', current_lr_D_B, n_iter)
            tb_logger.scalar_summary('lr_G_BA', current_lr_G_BA, n_iter)

            elapsed = str(datetime.timedelta(seconds=round(time.time() - start_t)))
            print(f"{elapsed:0>8}, epoch: {e + 1}, "
                  f"[{batch_id + 1}/{len(train_data_loader)}], "
                  f"total_G_AB_loss: {total_G_AB_loss.item()}, "
                  f"total_G_BA_loss: {total_G_BA_loss.item()}, "
                  f"iter time: {time.time() - start_t_batch}, "
                  f"d_A_loss: {d_A_loss.item()}, d_B_loss: {d_B_loss.item()}")

            n_iter += 1

        D_A_scheduler.step()
        G_AB_scheduler.step()
        D_B_scheduler.step()
        G_BA_scheduler.step()

        if (e + 1) % params['model_save_interval'] == 0 and e > 0:
            model_save_dir = os.path.join(params['model_save_dir'],
                                          params['exp_name'])
            os.makedirs(model_save_dir, exist_ok=True)
            D_A_ckpt_model_filename = "D_A_iter_%d.pth" % (e + 1)
            G_AB_ckpt_model_filename = "G_AB_iter_%d.pth" % (e + 1)
            D_A_model_save_path = os.path.join(model_save_dir,
                                               D_A_ckpt_model_filename)
            G_AB_model_save_path = os.path.join(model_save_dir,
                                                G_AB_ckpt_model_filename)
            D_B_ckpt_model_filename = "D_B_iter_%d.pth" % (e + 1)
            G_BA_ckpt_model_filename = "G_BA_iter_%d.pth" % (e + 1)
            model_ckpt_model_filename = "Cyclegan_iter_%d.pth" % (e + 1)
            D_B_model_save_path = os.path.join(model_save_dir,
                                               D_B_ckpt_model_filename)
            G_BA_model_save_path = os.path.join(model_save_dir,
                                                G_BA_ckpt_model_filename)
            model_all_path = os.path.join(model_save_dir,
                                          model_ckpt_model_filename)
            torch.save(
                {
                    'G_AB_state_dict': G_AB.module.state_dict(),
                    'G_BA_state_dict': G_BA.module.state_dict(),
                    'D_A_state_dict': D_A.module.state_dict(),
                    'D_B_state_dict': D_B.module.state_dict(),
                    'optimizer_G_AB_state_dict': optimizer_G_AB.state_dict(),
                    'optimizer_G_BA_state_dict': optimizer_G_BA.state_dict(),
                    'optimizer_D_A_state_dict': optimizer_D_A.state_dict(),
                    'optimizer_D_B_state_dict': optimizer_D_B.state_dict()
                }, model_all_path)
            torch.save(D_A.module.state_dict(), D_A_model_save_path)
            torch.save(G_AB.module.state_dict(), G_AB_model_save_path)
            torch.save(D_B.module.state_dict(), D_B_model_save_path)
            torch.save(G_BA.module.state_dict(), G_BA_model_save_path)
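
Code example #3 uses fake_A_buffer and fake_B_buffer without defining them; they are presumably module-level replay buffers in the style of the CycleGAN reference implementation. A minimal sketch under that assumption (the pool size of 50 and the 50% replacement probability follow the common convention, not necessarily this repository):

import random

import torch


class ReplayBuffer:
    """Pool of previously generated samples. Each query returns, with 50%
    probability, an older sample in place of the incoming one, which keeps
    the discriminator from overfitting to the latest generator output."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        out = []
        for element in batch.detach():
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Fill the pool first; pass the new sample straight through.
                self.data.append(element)
                out.append(element)
            elif random.random() > 0.5:
                # Swap a randomly chosen stored sample for the new one.
                idx = random.randint(0, self.max_size - 1)
                out.append(self.data[idx].clone())
                self.data[idx] = element
            else:
                out.append(element)
        return torch.cat(out, dim=0)


fake_A_buffer = ReplayBuffer()  # low-resolution (domain A) fakes
fake_B_buffer = ReplayBuffer()  # high-resolution (domain B) fakes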