Example #1
import signal
import time

import torch

# Project-specific helpers (signal_handler, setup, create_dataloader,
# create_model, Visualizer) are assumed to be imported from this repository.


def train_gpu(rank, world_size, opt, dataset):
    torch.cuda.set_device(opt.gpu_ids[rank])
    signal.signal(signal.SIGINT, signal_handler)   # really kill the process on interrupt
    signal.signal(signal.SIGTERM, signal_handler)
    if len(opt.gpu_ids) > 1:
        setup(rank, world_size, opt.ddp_port)
    dataloader = create_dataloader(opt, rank, dataset)  # create a dataloader given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset
    model = create_model(opt, rank)      # create a model given opt.model and other options

    if hasattr(model, 'data_dependent_initialize'):
        data = next(iter(dataloader))
        model.data_dependent_initialize(data)  # some models need one real batch before setup

    model.setup(opt)               # regular setup: load and print networks; create schedulers

    if len(opt.gpu_ids) > 1:
        model.parallelize(rank)
    else:
        model.single_gpu()

    if rank == 0:
        visualizer = Visualizer(opt)   # create a visualizer that displays/saves images and plots
    total_iters = 0                # the total number of training iterations

    if rank == 0:
        model.real_A_val, model.real_B_val = dataset.get_validation_set(opt.pool_size)
        model.real_A_val = model.real_A_val.to(model.device)
        model.real_B_val = model.real_B_val.to(model.device)

    if rank == 0 and opt.display_networks:
        data = next(iter(dataloader))
        for path in model.save_networks_img(data):
            visualizer.display_img(path + '.png')

    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        if rank == 0:
            visualizer.reset()              # make sure the visualizer saves results to HTML at least once per epoch

        for i, data in enumerate(dataloader):  # inner loop (minibatch) within one epoch
            
            iter_start_time = time.time()  # timer for computation per iteration
            t_data_mini_batch = iter_start_time - iter_data_time
            
            model.set_input(data)         # unpack data from dataloader and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
            
            t_comp = (time.time() - iter_start_time) / opt.batch_size

            batch_size = model.get_current_batch_size() * len(opt.gpu_ids)
            total_iters += batch_size
            epoch_iter += batch_size

            if rank == 0:
                if total_iters % opt.display_freq < batch_size:   # display images on visdom and save them to an HTML file
                    save_result = total_iters % opt.update_html_freq == 0
                    model.compute_visuals()
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result, params=model.get_display_param())

                if total_iters % opt.print_freq < batch_size:    # print training losses and save logging information to disk
                    losses = model.get_current_losses()
                    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data_mini_batch)
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

                if total_iters % opt.save_latest_freq < batch_size:   # cache our latest model every <save_latest_freq> iterations
                    print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                    save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                    model.save_networks(save_suffix)

                if opt.compute_fid and total_iters % opt.fid_every < batch_size:
                    model.compute_fid(epoch, total_iters)
                    if opt.display_id > 0:
                        fids = model.get_current_fids()
                        visualizer.plot_current_fid(epoch, float(epoch_iter) / dataset_size, fids)

                if opt.compute_D_accuracy and total_iters % opt.D_accuracy_every < batch_size:
                    model.compute_D_accuracy()
                    if opt.display_id > 0:
                        accuracies = model.get_current_D_accuracies()
                        visualizer.plot_current_D_accuracies(epoch, float(epoch_iter) / dataset_size, accuracies)

                if opt.APA and total_iters % opt.display_freq < batch_size:
                    if opt.display_id > 0:
                        p = model.get_current_APA_prob()
                        visualizer.plot_current_APA_prob(epoch, float(epoch_iter) / dataset_size, p)
                    
    
            iter_data_time = time.time()  # reset the data-loading timer on every rank, not only rank 0

        if rank == 0 and epoch % opt.save_epoch_freq == 0:   # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        if rank == 0:
            print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))    
        model.update_learning_rate()                     # update learning rates at the end of every epoch.
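
Example #1 depends on helpers that this page does not show. The sketch below is a guess at their usual shape, not this repository's actual code: setup typically wraps torch.distributed.init_process_group (the NCCL backend and localhost rendezvous are assumptions), and train_gpu is normally launched with one worker process per GPU via torch.multiprocessing.spawn, which passes each worker its rank as the first argument.

import os

import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size, port):
    # Hypothetical sketch: all workers rendezvous at the same address/port,
    # then join the process group that model.parallelize() relies on.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def launch(opt, dataset):
    # Assumed launch pattern: mp.spawn starts len(opt.gpu_ids) processes,
    # each running train_gpu(rank, world_size, opt, dataset).
    world_size = len(opt.gpu_ids)
    mp.spawn(train_gpu, args=(world_size, opt, dataset), nprocs=world_size, join=True)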
Example #2
# Fragment: the source shows only the tail of a similar training loop;
# the enclosing epoch and minibatch loops are omitted on this page.
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq < model.opt.batch_size:  # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' %
                      (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            if total_iters % opt.fid_every < model.opt.batch_size and opt.compute_fid:
                model.compute_fid(epoch, total_iters)
                if opt.display_id > 0:
                    fids = model.get_current_fids()
                    visualizer.plot_current_fid(
                        epoch,
                        float(epoch_iter) / dataset_size, fids)

            iter_data_time = time.time()

        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.n_epochs + opt.n_epochs_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate()  # update learning rates at the end of every epoch
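
Both examples gate periodic work with total_iters % freq < batch_size rather than total_iters % freq == 0. Because total_iters advances by the batch size on every step, it may never land exactly on a multiple of freq; the inequality instead fires once each time a multiple of freq is crossed (assuming the batch size is smaller than freq). A small standalone illustration with arbitrary numbers:

batch_size, freq = 12, 100
steps = range(0, 1000, batch_size)  # total_iters grows by batch_size each step
fired = [it for it in steps if it % freq < batch_size]
print(fired)  # [0, 108, 204, 300, 408, 504, 600, 708, 804, 900]
# A strict `== 0` test would fire only at multiples of lcm(12, 100) = 300.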