Example #1
def tester(args):
    print ("Evaluating Thopte and Prabhdeep's Model...")


    image_saved_path = parameters.images_dir
    if not os.path.exists(image_saved_path):
        os.makedirs(image_saved_path)

    if args.use_visdom:
        vis = visdom.Visdom()

    save_file_path = parameters.output_dir + '/' + args.model_name
    pretrained_file_path_G = save_file_path+'/'+'G.pth'
    pretrained_file_path_D = save_file_path+'/'+'D.pth'
    
    print (pretrained_file_path_G)

    D = net_D(args)
    G = net_G(args)

    # remap CUDA tensors to CPU only when no GPU is available
    if not torch.cuda.is_available():
        G.load_state_dict(torch.load(pretrained_file_path_G, map_location={'cuda:0': 'cpu'}))
        D.load_state_dict(torch.load(pretrained_file_path_D, map_location={'cuda:0': 'cpu'}))
    else:
        G.load_state_dict(torch.load(pretrained_file_path_G))
        D.load_state_dict(torch.load(pretrained_file_path_D))
    
    print ('visualizing model')
    
    G.to(parameters.device)
    D.to(parameters.device)
    G.eval()
    D.eval()

    N = 8

    for i in range(N):
        z = generateZ(args, 1)
        
        fake = G(z)
        samples = fake.unsqueeze(dim=0).detach().cpu().numpy()  # move to CPU before converting to NumPy

        y_prob = D(fake)
        y_real = torch.ones_like(y_prob)
        if not args.use_visdom:
            SavePloat_Voxels(samples, image_saved_path, 'tester__'+str(i))
        else:
            plotVoxelVisdom(samples[0,:], vis, "tester_"+str(i))
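
Both the testers and the trainers call a generateZ helper that none of the examples define. The sketch below is a plausible minimal implementation, assuming the hyperparameter module exposes z_dim (latent size) and z_dis (distribution name); both names are assumptions, not confirmed by the snippets. The trainer in Example #2 additionally passes a data_idx argument, presumably to condition the latent code on the object category.

import torch

def generateZ(args, batch):
    # Hypothetical latent sampler; parameters.z_dim and parameters.z_dis are assumed names.
    if parameters.z_dis == "norm":    # zero-mean Gaussian codes
        Z = torch.randn(batch, parameters.z_dim).to(parameters.device) * 0.33
    elif parameters.z_dis == "uni":   # uniform codes in [0, 1)
        Z = torch.rand(batch, parameters.z_dim).to(parameters.device)
    else:
        raise ValueError("unknown latent distribution: " + parameters.z_dis)
    return Z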
Example #2
def trainer(args):

    # added for output dir
    save_file_path = params.output_dir + '/' + args.model_name
    print (save_file_path)  # ../outputs/dcgan
    if not os.path.exists(save_file_path):
        os.makedirs(save_file_path)

    # for using tensorboard
    if args.logs:
        model_uid = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        writer = SummaryWriter(params.output_dir+'/'+args.model_name+'/logs_'+model_uid+'_'+args.logs+'/')

    ####################################################
    # datset define
    # dsets_path = args.input_dir + args.data_dir + "train/"


    
    # print (dset_len["train"])
    ####################################################

    # model define
    D = net_D(args)
    G = net_G(args)
    Q = net_Q(args)

    # Load
    # G.load_state_dict(torch.load(params.output_dir+'/'+args.model_name+'/'+'G589.pth'))
    # Q.load_state_dict(torch.load(params.output_dir+'/'+args.model_name+'/'+'Q589.pth'))
    # D.load_state_dict(torch.load(params.output_dir+'/'+args.model_name+'/'+'D589.pth', map_location={'cuda:0': 'cpu'}))

    # print total number of parameters in a model
    # x = sum(p.numel() for p in G.parameters() if p.requires_grad)
    # print (x)
    # x = sum(p.numel() for p in D.parameters() if p.requires_grad)
    # print (x)

    D_solver = optim.Adam(D.parameters(), lr=params.d_lr, betas=params.beta)
    # D_solver = optim.SGD(D.parameters(), lr=args.d_lr, momentum=0.9)
    G_solver = optim.Adam(G.parameters(), lr=params.g_lr, betas=params.beta)
    Q_solver = optim.Adam(Q.parameters(), lr=params.d_lr, betas=params.beta)

    D.to(params.device)
    G.to(params.device)
    Q.to(params.device)
    

    # criterion_D = nn.BCELoss()
    criterion_D = nn.MSELoss()
    criterion_G = nn.L1Loss()
    criterion_Q = nn.MSELoss()

    itr_val = -1
    itr_train = -1


    noise_prob = 0.2  # chance of replacing one sample per batch with random noise

    for epoch in range(params.epochs):

        #####################################################################################
        # alternate between the two object categories every five epochs
        if epoch % 10 < 5:
            data_idx = 0
        else:
            data_idx = 1
        dsets_path = params.data_dir + params.model_dir[data_idx] + "30/train/"
        #dsets_path = params.data_dir + params.model_dir

        # if params.cube_len == 64:
        #     dsets_path = params.data_dir + params.model_dir + "30/train64/"

        print (dsets_path)   # ../volumetric_data/chair/30/train/

        train_dsets = ShapeNetDataset(dsets_path, args, "train")
        #print("train_dsets : ", train_dsets.shape)
        # val_dsets = ShapeNetDataset(dsets_path, args, "val")
        
        train_dset_loaders = torch.utils.data.DataLoader(train_dsets, batch_size=params.batch_size, shuffle=True, num_workers=1)
        # val_dset_loaders = torch.utils.data.DataLoader(val_dsets, batch_size=args.batch_size, shuffle=True, num_workers=1)
        
        dset_len = {"train": len(train_dsets)}
        dset_loaders = {"train": train_dset_loaders}
        ########################################################################################################

        start = time.time()
        
        for phase in ['train']:
            if phase == 'train':
                # if args.lrsh:
                #     D_scheduler.step()
                D.train()
                G.train()
                Q.train()
            else:
                D.eval()
                G.eval()
                Q.eval()

            running_loss_G = 0.0
            running_loss_D = 0.0
            running_loss_Q = 0.0
            running_loss_adv_G = 0.0

            for i, X in enumerate(tqdm(dset_loaders[phase])):

                # occasionally corrupt one training sample with Bernoulli noise
                if np.random.random_sample(1) < noise_prob:
                    noise_idx = np.random.randint(X.size()[0], size=1)
                    X[noise_idx] = torch.FloatTensor(np.random.binomial(1, 0.5, (64, 64, 64)))  # assumes cube_len == 64

                # if phase == 'val':
                #     itr_val += 1

                if phase == 'train':
                    itr_train += 1

                #print (X.shape)    
                X = X.to(params.device)
                #print (X)
                #print ('X : ',X.shape)
                
                batch = X.size()[0]
                # print (batch)

                Z = generateZ(args, batch, data_idx)
                # print (Z.size())

                # ============= Train the discriminator =============#
                d_real = D(X)

                c_real = Q(X)
                if data_idx == 0:
                    c_real_labels = torch.full((batch,), 1.0).to(params.device)
                else:
                    c_real_labels = torch.full((batch,), 2.0).to(params.device)
                c_real_loss = criterion_Q(c_real, c_real_labels)

                fake = G(Z)
                d_fake = D(fake)


                real_labels = torch.ones_like(d_real).to(params.device)
                fake_labels = torch.zeros_like(d_fake).to(params.device)
                # print (d_fake.size(), fake_labels.size())

                if params.soft_label:
                    real_labels = torch.Tensor(batch).uniform_(0.7, 1.2).to(params.device)
                    fake_labels = torch.Tensor(batch).uniform_(0, 0.3).to(params.device)

                # print (d_real.size(), real_labels.size())
                d_real_loss = criterion_D(d_real, real_labels)
                

                d_fake_loss = criterion_D(d_fake, fake_labels)

                d_loss = d_real_loss + d_fake_loss

                # discriminator accuracy, used below to gate its update
                d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
                d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
                d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))


                if d_total_acu < params.d_thresh:
                    D.zero_grad()
                    d_loss.backward()
                    D_solver.step()

                    Q.zero_grad()
                    c_real_loss.backward()
                    Q_solver.step()

                # =============== Train the generator ===============#
                
                Z = generateZ(args, batch, data_idx)

                # print (X)
                fake = G(Z) # generated fake: 0-1, X: 0/1
                d_fake = D(fake)
                c_fake = Q(fake)

                adv_g_loss = criterion_D(d_fake, real_labels)
                # print (fake.size(), X.size())
                adv_c_loss = criterion_Q(c_fake, c_real_labels)

                # recon_g_loss = criterion_D(fake, X)
                recon_g_loss = criterion_G(fake, X)   # tracked for logging only

                # g_loss = recon_g_loss + params.adv_weight * adv_g_loss
                g_loss = adv_g_loss + adv_c_loss

                if args.local_test:
                    # print('Iteration-{} , D(x) : {:.4} , G(x) : {:.4} , D(G(x)) : {:.4}'.format(itr_train, d_loss.item(), recon_g_loss.item(), adv_g_loss.item()))
                    print('Iteration-{} , D(x) : {:.4}, D(G(x)) : {:.4}'.format(itr_train, d_loss.item(), adv_g_loss.item()))

                D.zero_grad()
                G.zero_grad()
                g_loss.backward()
                G_solver.step()

                # =============== logging each 10 iterations ===============#

                running_loss_G += recon_g_loss.item() * X.size(0)
                running_loss_D += d_loss.item() * X.size(0)
                running_loss_adv_G += adv_g_loss.item() * X.size(0)

                if args.logs:
                    loss_G = {
                        'adv_loss_G': adv_g_loss,
                        'recon_loss_G': recon_g_loss,   
                    }

                    loss_D = {
                        'adv_real_loss_D': d_real_loss,
                        'adv_fake_loss_D': d_fake_loss,
                    }

                    # if itr_val % 10 == 0 and phase == 'val':
                    #     save_val_log(writer, loss_D, loss_G, itr_val)

                    if itr_train % 10 == 0 and phase == 'train':
                        save_train_log(writer, loss_D, loss_G, itr_train)

           
            # =============== each epoch save model or save image ===============#
            epoch_loss_G = running_loss_G / dset_len[phase]
            epoch_loss_D = running_loss_D / dset_len[phase]
            epoch_loss_adv_G = running_loss_adv_G / dset_len[phase]


            end = time.time()
            epoch_time = end - start


            print('Epochs-{} ({}) , D(x) : {:.4}, D(G(x)) : {:.4}'.format(epoch, phase, epoch_loss_D, epoch_loss_adv_G))
            print ('Elapsed Time: {:.4} min'.format(epoch_time/60.0))

            if (epoch + 1) % params.model_save_step == 0:

                print ('model_saved, images_saved...')
                torch.save(G.state_dict(), params.output_dir + '/' + args.model_name + '/' + 'G' +str(epoch) + '.pth')
                torch.save(D.state_dict(), params.output_dir + '/' + args.model_name + '/' + 'D' + str(epoch)+ '.pth')
                torch.save(Q.state_dict(), params.output_dir + '/' + args.model_name + '/' + 'Q' + str(epoch)+ '.pth')

                samples = fake.detach().cpu()[:8].squeeze().numpy()  # .detach() instead of deprecated .data
                # print (samples.shape)
                # image_saved_path = '../images'
                image_saved_path = params.images_dir
                if not os.path.exists(image_saved_path):
                    os.makedirs(image_saved_path)

                SavePloat_Voxels(samples, image_saved_path, epoch)
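
Two details of this loop are worth isolating. First, the discriminator's targets are softened (real in [0.7, 1.2], fake in [0, 0.3]) and its update is skipped whenever its batch accuracy already exceeds params.d_thresh, which keeps D from overpowering G. Second, the auxiliary Q network regresses a per-category code (1.0 or 2.0) on real data, and the generator is trained so that Q recovers the same code from fakes, an InfoGAN-flavored conditioning scheme. A self-contained sketch of the first trick, with stand-in tensors and an assumed threshold of 0.8:

import torch

d_real = torch.rand(4)   # stand-in for D(X)
d_fake = torch.rand(4)   # stand-in for D(G(Z))

# soft labels: noisy targets instead of hard 1/0
real_labels = torch.empty(4).uniform_(0.7, 1.2)
fake_labels = torch.empty(4).uniform_(0.0, 0.3)

# accuracy gating: update D only while it is still below threshold
d_real_acu = (d_real >= 0.5).float()
d_fake_acu = (d_fake <= 0.5).float()
d_total_acu = torch.cat((d_real_acu, d_fake_acu)).mean()
update_D = d_total_acu < 0.8   # assumed value for params.d_thresh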
Example #3
def tester(args):
    print ('Evaluation Mode...')

    # image_saved_path = '../images'
    image_saved_path = params.images_dir
    if not os.path.exists(image_saved_path):
        os.makedirs(image_saved_path)

    if args.use_visdom:
        vis = visdom.Visdom()

    save_file_path = params.output_dir + '/' + args.model_name
    pretrained_file_path_G = save_file_path+'/'+'G.pth'
    pretrained_file_path_D = save_file_path+'/'+'D.pth'
    
    print (pretrained_file_path_G)

    D = net_D(args)
    G = net_G(args)

    # remap CUDA tensors to CPU only when no GPU is available
    if not torch.cuda.is_available():
        G.load_state_dict(torch.load(pretrained_file_path_G, map_location={'cuda:0': 'cpu'}))
        D.load_state_dict(torch.load(pretrained_file_path_D, map_location={'cuda:0': 'cpu'}))
    else:
        G.load_state_dict(torch.load(pretrained_file_path_G))
        D.load_state_dict(torch.load(pretrained_file_path_D))
    
    print ('visualizing model')
    
    # test generator
    # test_gen(args)
    G.to(params.device)
    D.to(params.device)
    G.eval()
    D.eval()

    # test_z = np.load("test_z.npy")
    # print (test_z.shape)
    # N = test_z.shape[0]

    N = 8

    for i in range(N):
        # z = test_z[i,:]
        # z = torch.FloatTensor(z)
        
        z = generateZ(args, 1)
        
        # print (z.size())
        fake = G(z)
        samples = fake.unsqueeze(dim=0).detach().cpu().numpy()
        # print (samples.shape)
        # print (fake)
        y_prob = D(fake)
        y_real = torch.ones_like(y_prob)
        # criterion = nn.BCELoss()
        # print (y_prob.item(), criterion(y_prob, y_real).item())

        ### visualization
        
        SavePloat_Voxels(samples, image_saved_path, 'tester_norm_'+str(i))
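
None of the examples show how args is constructed. A hypothetical command-line wiring that covers only the attributes these snippets read directly (the flag names and defaults are assumptions); the networks and dataset also receive args and may read further attributes not visible here:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='dcgan')  # subdirectory under output_dir
parser.add_argument('--logs', type=str, default='')             # tensorboard run tag; empty disables logging
parser.add_argument('--use_visdom', action='store_true')        # plot voxels with visdom instead of saving images
parser.add_argument('--local_test', action='store_true')        # print per-iteration losses
args = parser.parse_args()

tester(args)   # or trainer(args)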
Example #4
def trainer(args):

    
    save_file_path = parameters.output_dir + '/' + args.model_name
    print (save_file_path) 
    if not os.path.exists(save_file_path):
        os.makedirs(save_file_path)

    
    if args.logs:
        model_uid = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        writer = SummaryWriter(parameters.output_dir+'/'+args.model_name+'/logs_'+model_uid+'_'+args.logs+'/')
    
    dsets_path = parameters.data_dir + parameters.model_dir + "30/train/"
    

    print (dsets_path)   

    train_dsets = ShapeNetDataset(dsets_path, args, "train")
    
    
    train_dset_loaders = torch.utils.data.DataLoader(train_dsets, batch_size=parameters.batch_size, shuffle=True, num_workers=1)
    
    dset_len = {"train": len(train_dsets)}
    dset_loaders = {"train": train_dset_loaders}
    

    D = net_D(args)
    G = net_G(args)


    D_solver = optim.Adam(D.parameters(), lr=parameters.d_lr, betas=parameters.beta)
    G_solver = optim.Adam(G.parameters(), lr=parameters.g_lr, betas=parameters.beta)

    D.to(parameters.device)
    G.to(parameters.device)
    
    criterion_D = nn.MSELoss()
    criterion_G = nn.L1Loss()

    itr_val = -1
    iter_tr = -1

    for epoch in range(parameters.epochs):

        start = time.time()
        
        for phase in ['train']:
            if phase == 'train':
                D.train()
                G.train()
            else:
                D.eval()
                G.eval()

            running_loss_G = 0.0
            running_loss_D = 0.0
            running_loss_adv_G = 0.0

            for i, X in enumerate(tqdm(dset_loaders[phase])):

                if phase == 'train':
                    iter_tr += 1

                X = X.to(parameters.device)
                
                batch = X.size()[0]
                

                Z = generateZ(args, batch)

                # Sid: train the discriminator
                d_real = D(X)

                

                fake = G(Z)
                d_fake = D(fake)

                real_labels = torch.ones_like(d_real).to(parameters.device)
                fake_labels = torch.zeros_like(d_fake).to(parameters.device)
                

                if parameters.soft_label:
                    real_labels = torch.Tensor(batch).uniform_(0.7, 1.2).to(parameters.device)
                    fake_labels = torch.Tensor(batch).uniform_(0, 0.3).to(parameters.device)

                
                d_real_loss = criterion_D(d_real, real_labels)
                

                d_fake_loss = criterion_D(d_fake, fake_labels)

                d_loss = d_real_loss + d_fake_loss

                
                d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
                d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
                d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu),0))


                if d_total_acu < parameters.d_thresh:
                    D.zero_grad()
                    d_loss.backward()
                    D_solver.step()

                # Thopte: train the generator
                
                Z = generateZ(args, batch)

                
                fake = G(Z) 
                d_fake = D(fake)

                adv_g_loss = criterion_D(d_fake, real_labels)
            
                recon_g_loss = criterion_G(fake, X)   # tracked for logging only
                g_loss = adv_g_loss

                if args.local_test:
                    print('Iteration-{} , D(x) : {:.4}, D(G(x)) : {:.4}'.format(iter_tr, d_loss.item(), adv_g_loss.item()))

                D.zero_grad()
                G.zero_grad()
                g_loss.backward()
                G_solver.step()

                

                running_loss_G += recon_g_loss.item() * X.size(0)
                running_loss_D += d_loss.item() * X.size(0)
                running_loss_adv_G += adv_g_loss.item() * X.size(0)

                if args.logs:
                    loss_G = {
                        'adv_loss_G': adv_g_loss,
                        'recon_loss_G': recon_g_loss,   
                    }

                    loss_D = {
                        'adv_real_loss_D': d_real_loss,
                        'adv_fake_loss_D': d_fake_loss,
                    }

                
                    if iter_tr % 10 == 0 and phase == 'train':
                        save_train_logs(writer, loss_D, loss_G, iter_tr)

           
            
            epoch_loss_G = running_loss_G / dset_len[phase]
            epoch_loss_D = running_loss_D / dset_len[phase]
            epoch_loss_adv_G = running_loss_adv_G / dset_len[phase]


            end = time.time()
            epoch_time = end - start


            print('Epochs-{} ({}) , D(x) : {:.4}, D(G(x)) : {:.4}'.format(epoch, phase, epoch_loss_D, epoch_loss_adv_G))
            print ('Elapsed Time: {:.4} min'.format(epoch_time/60.0))

            if (epoch + 1) % parameters.model_save_step == 0:

                print ('model_saved, images_saved...')
                torch.save(G.state_dict(), parameters.output_dir + '/' + args.model_name + '/' + 'G' + '.pth')
                torch.save(D.state_dict(), parameters.output_dir + '/' + args.model_name + '/' + 'D' + '.pth')

                samples = fake.detach().cpu()[:8].squeeze().numpy()  # .detach() instead of deprecated .data
                image_saved_path = parameters.images_dir
                if not os.path.exists(image_saved_path):
                    os.makedirs(image_saved_path)

                Save_Plot_Voxels(samples, image_saved_path, epoch)
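
Save_Plot_Voxels (spelled SavePloat_Voxels in Examples #1 through #3) is never defined in these snippets. A minimal stand-in under the assumption that each sample is a cubic occupancy grid: binarize it and render with matplotlib's 3D voxel plot.

import numpy as np
import matplotlib
matplotlib.use('Agg')   # render without a display
import matplotlib.pyplot as plt

def save_voxels(samples, path, tag, threshold=0.5):
    # Hypothetical replacement for Save_Plot_Voxels: one PNG per sample.
    for idx, vox in enumerate(np.asarray(samples)):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.voxels(vox > threshold, edgecolor='k')
        fig.savefig('{}/{}_{}.png'.format(path, tag, idx))
        plt.close(fig)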