Exemplo n.º 1
0
    def __init__(self, args, model, evaluation, writer):
        """Set up test-time logging and progress monitoring."""
        self.args = args
        self.model = model
        self.evaluation = evaluation
        self.writer = writer

        self.nepochs = args.nepochs

        # logging testing
        self.log_loss = plugins.Logger(args.logs_dir, 'TestLogger.txt',
                                       args.save_results)
        self.params_log = ['ACC', 'STD']
        self.log_loss.register(self.params_log)

        # monitor testing: every logged quantity is tracked as a running mean
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            name: {'dtype': 'running_mean'} for name in self.params_log
        }
        self.monitor.register(self.params_monitor)

        # display testing progress: one "%.4f" slot per logged quantity
        self.print_formatter = 'Test [%d/%d]] ' + ''.join(
            name + " %.4f " for name in self.params_log)

        self.losses = {}
Exemplo n.º 2
0
    def __init__(self, args, netD, netG, netE):
        """Set up generation-time buffers, logging and monitoring.

        Args:
            args: parsed options (resolutions, nz, cuda flag, logs dir, ...).
            netD: discriminator network.
            netG: generator network.
            netE: encoder network.
        """
        self.netD = netD
        self.netG = netG
        self.netE = netE

        self.args = args
        self.nchannels = args.nchannels
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nz = args.nz
        self.wcom = args.disc_loss_weight
        self.cuda = args.cuda
        self.citers = args.citers
        self.lr = args.learning_rate_vae
        self.momentum = args.momentum
        self.batch_size = args.batch_size
        self.use_encoder = args.use_encoder

        # Inference-only buffers (old-style autograd API).
        # BUGFIX: the original called .cuda() unconditionally on `input` and
        # `epsilon` here and then again inside the `if args.cuda` guard below,
        # which crashes on CPU-only machines; tensors are now built on CPU and
        # moved to the GPU only when requested.
        self.input = Variable(torch.FloatTensor(self.batch_size,
                                                self.nchannels,
                                                self.resolution_high,
                                                self.resolution_wide),
                              volatile=True)
        self.epsilon = Variable(torch.randn(self.batch_size, self.nz),
                                volatile=True)
        self.noise = Variable(torch.FloatTensor(self.batch_size, self.nz, 1,
                                                1).normal_(0, 1),
                              volatile=True)

        if args.cuda:
            self.input = self.input.cuda()
            self.epsilon = self.epsilon.cuda()
            self.noise = self.noise.cuda()

        # logging generation metrics; the same metric names drive both the
        # text logger and the in-memory monitor
        self.log_eval_loss = plugins.Logger(args.logs, 'Generation.txt')
        self.params_eval_loss = [
            'Image', 'SSIM_1', 'SSIM_2', 'SSIM_3', 'SSIM_4', 'PSNR_1',
            'PSNR_2', 'PSNR_3', 'PSNR_4', 'DiscScore_1', 'DiscScore_2',
            'DiscScore_3', 'DiscScore_4'
        ]
        self.log_eval_loss.register(self.params_eval_loss)

        self.losses = {}
        self.log_eval_monitor = plugins.Monitor()
        # same metric list as the logger (was a duplicated literal)
        self.params_eval_monitor = list(self.params_eval_loss)
        self.log_eval_monitor.register(self.params_eval_monitor)

        # progress-line formatter, e.g. "[1/10] Image 0.1234 SSIM_1 ..."
        self.print = '[%d/%d] '
        for item in self.params_eval_loss:
            self.print = self.print + item + " %.4f "
Exemplo n.º 3
0
    def __init__(self, args, model, criterion, evaluation, optimizer, writer):
        """Set up training losses, LR scheduling, logging and monitoring."""
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.optimizer = optimizer
        self.writer = writer

        self.nepochs = args.nepochs

        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method

        self.att_loss = losses.AttMatrixCov(args.cuda)
        self.dist_loss = losses.DebiasIntraDist(args.cuda)

        # 'Customer' marks a hand-rolled schedule handled elsewhere;
        # any other name is looked up in torch.optim.lr_scheduler
        if self.scheduler_method is not None and \
                self.scheduler_method != 'Customer':
            scheduler_cls = getattr(optim.lr_scheduler, self.scheduler_method)
            self.scheduler = scheduler_cls(self.optimizer,
                                           **args.scheduler_options)

        # logging training
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       args.save_results)
        self.params_loss = ['LearningRate', 'idLoss', 'attLoss']
        self.log_loss.register(self.params_loss)

        # monitor training: each quantity tracked as a running mean
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            key: {'dtype': 'running_mean'} for key in self.params_loss
        }
        self.monitor.register(self.params_monitor)

        # display training progress: one "%.4f" slot per logged quantity
        self.print_formatter = 'Train [%d/%d][%d/%d] ' + ''.join(
            key + " %.4f " for key in self.params_loss)

        self.losses = {}
Exemplo n.º 4
0
    def __init__(self, args, model, criterion, evaluation):
        """Wire up optimizer, scheduler, batch buffers, logging, monitoring
        and visualization for classification training."""
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.save_results = args.save_results

        self.env = args.env
        self.port = args.port
        self.dir_save = args.save_dir
        self.log_type = args.log_type

        self.device = args.device
        self.nepochs = args.nepochs
        self.batch_size = args.batch_size

        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method
        self.scheduler_options = args.scheduler_options

        # only parameters that require gradients are optimized
        trainable = filter(lambda p: p.requires_grad, model.parameters())
        optim_cls = getattr(optim, self.optim_method)
        self.optimizer = optim_cls(trainable, lr=self.lr,
                                   **self.optim_options)
        if self.scheduler_method is not None:
            scheduler_cls = getattr(optim.lr_scheduler, self.scheduler_method)
            self.scheduler = scheduler_cls(self.optimizer,
                                           **self.scheduler_options)

        # preallocated batch buffers for classification
        self.labels = torch.zeros(self.batch_size,
                                  dtype=torch.long,
                                  device=self.device)
        self.inputs = torch.zeros(self.batch_size,
                                  self.resolution_high,
                                  self.resolution_wide,
                                  device=self.device)

        # logging training
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       self.save_results)
        self.params_loss = ['Loss', 'Accuracy']
        self.log_loss.register(self.params_loss)

        # monitor training: running means of loss and accuracy
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            key: {'dtype': 'running_mean'} for key in self.params_loss
        }
        self.monitor.register(self.params_monitor)

        # visualize training; scalar plots share a train/test layout
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Train')
        shared_layout = {'windows': ['train', 'test'], 'id': 0}
        self.params_visualizer = {
            'Loss': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss',
                'layout': dict(shared_layout)
            },
            'Accuracy': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'accuracy',
                'layout': dict(shared_layout)
            },
            'Train_Image': {
                'dtype': 'image',
                'vtype': 'image',
                'win': 'train_image'
            },
            'Train_Images': {
                'dtype': 'images',
                'vtype': 'images',
                'win': 'train_images'
            },
        }
        self.visualizer.register(self.params_visualizer)

        if self.log_type == 'traditional':
            # classic one-line progress: one "%.4f" slot per logged quantity
            self.print_formatter = 'Train [%d/%d][%d/%d] ' + ''.join(
                key + " %.4f " for key in self.params_loss)
        elif self.log_type == 'progressbar':
            # progress bar message formatter
            self.print_formatter = '({}/{})' \
                                   ' Load: {:.6f}s' \
                                   ' | Process: {:.3f}s' \
                                   ' | Total: {:}' \
                                   ' | ETA: {:}'
            for key in self.params_loss:
                self.print_formatter += ' | ' + key + ' {:.4f}'
            self.print_formatter += ' | lr: {:.2e}'

        self.evalmodules = []
        self.losses = {}
Exemplo n.º 5
0
    def __init__(self, args, model, criterion, evaluation):
        """Set up the two-optimizer training state (feature CNN vs. the
        classifier heads + discriminator), plus scheduling, buffers,
        logging, monitoring and visualization."""
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation

        self.nepochs = args.nepochs

        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method

        optim_cls = getattr(optim, self.optim_method)
        # one optimizer for the feature extractor ...
        self.optimizer_cnn = optim_cls(model['feat'].parameters(),
                                       lr=self.lr, **self.optim_options)

        # ... and one for all classification heads plus the discriminator
        module_list = nn.ModuleList([
            criterion['id'], criterion['gender'], criterion['age'],
            criterion['race'], model['discrim']
        ])
        self.optimizer_cls = optim_cls(module_list.parameters(),
                                       lr=self.lr, **self.optim_options)

        # 'Customer' marks a hand-rolled schedule handled elsewhere; only the
        # CNN optimizer is driven by the torch scheduler
        if self.scheduler_method is not None and \
                self.scheduler_method != 'Customer':
            scheduler_cls = getattr(optim.lr_scheduler, self.scheduler_method)
            self.scheduler = scheduler_cls(self.optimizer_cnn,
                                           **args.scheduler_options)

        # preallocated batch buffers for classification
        self.labels = torch.zeros(args.batch_size).long()
        self.inputs = torch.zeros(args.batch_size, args.resolution_high,
                                  args.resolution_wide)

        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()

        self.inputs = Variable(self.inputs)
        self.labels = Variable(self.labels)

        # logging training
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       args.save_results)
        params_loss = ['LearningRate', 'Loss_cls_demog', 'Loss_cls_id',
                       'Loss_conf_demog', 'Loss_conf_id', 'Loss_cls_mi',
                       'Loss_conf_mi']
        self.log_loss.register(params_loss)

        # monitor training: each quantity tracked as a running mean
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            key: {'dtype': 'running_mean'}
            for key in ('LearningRate', 'Loss_cls_demog', 'Loss_cls_id',
                        'Loss_cls_mi', 'Loss_conf_demog', 'Loss_conf_id',
                        'Loss_conf_mi')
        }
        self.monitor.register(self.params_monitor)

        # visualize training: scalar curves grouped into one window per
        # loss family ('loss_cls' / 'loss_conf'), plus image panels
        self.visualizer = plugins.Visualizer(args.port, args.env, 'Train')
        scalar_windows = (
            ('LearningRate', 'learning_rate'),
            ('Loss_cls_demog', 'loss_cls'),
            ('Loss_cls_id', 'loss_cls'),
            ('Loss_cls_mi', 'loss_cls'),
            ('Loss_conf_demog', 'loss_conf'),
            ('Loss_conf_id', 'loss_conf'),
            ('Loss_conf_mi', 'loss_conf'),
        )
        params_visualizer = {
            name: {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': win,
                'layout': {'windows': ['train', 'test'], 'id': 0}
            }
            for name, win in scalar_windows
        }
        params_visualizer['Train_Image'] = {
            'dtype': 'image',
            'vtype': 'image',
            'win': 'train_image'
        }
        params_visualizer['Train_Images'] = {
            'dtype': 'images',
            'vtype': 'images',
            'win': 'train_images'
        }
        self.visualizer.register(params_visualizer)

        # display training progress: one "%.4f" slot per logged quantity
        self.print_formatter = 'Train [%d/%d][%d/%d] ' + ''.join(
            name + " %.4f " for name in params_loss)

        self.losses = {}
        # per-bin representative age values (presumably bin midpoints used to
        # map a predicted age-bin index to a scalar age — verify with caller)
        self.binage = torch.Tensor(
            [10, 22.5, 27.5, 32.5, 37.5, 42.5, 47.5, 55, 75])
Exemplo n.º 6
0
    def __init__(self, args, modelD, modelG, Encoder, criterion, prevD, prevG):
        """Set up staged GAN/VAE training state.

        Builds the device, target tensors, the three optimizers
        (encoder / generator / discriminator) with automatic LR schedulers,
        plus text logging, running monitors and the visdom visualizer.

        Args:
            args: parsed command-line options.
            modelD: discriminator network.
            modelG: generator network.
            Encoder: encoder network (VAE branch).
            criterion: loss module(s).
            prevD: discriminator from the previous training stage.
            prevG: generator from the previous training stage.
        """

        self.args = args
        # NOTE(review): both list slots reference the SAME module object —
        # the comprehension repeats the reference, it does not copy the model.
        self.modelD = [modelD for i in range(2)]
        self.modelG = [modelG for i in range(2)]
        self.Encoder = Encoder
        self.prevD = prevD
        self.prevG = prevG
        self.criterion = criterion
        self.cuda = args.cuda
        # fall back to CPU unless CUDA was requested AND is actually available
        self.device = torch.device("cuda" if (
            self.cuda and torch.cuda.is_available()) else "cpu")
        self.logits_loss = RankOrderLoss(self.device)
        self.logits_eval = Logits_Classification(threshold=0.5)
        self.plot_update_interval = args.plot_update_interval

        # paths and bookkeeping
        self.port = args.port
        self.env = args.env
        self.result_path = args.result_path
        self.save_path = args.save
        self.log_path = args.logs
        self.dataset_fraction = args.dataset_fraction
        self.len_dataset = 0

        # data / architecture dimensions and GP (gradient-penalty) options
        self.stage_epochs = args.stage_epochs
        self.start_stage = args.start_stage
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nz = args.nz
        self.gp = args.gp
        self.gp_lambda = args.gp_lambda
        self.scheduler_patience = args.scheduler_patience
        self.scheduler_maxlen = args.scheduler_maxlen

        # loss weighting / ranking options
        self.weight_gan_final = args.weight_gan_final
        self.weight_vae_init = args.weight_vae_init
        self.weight_kld = args.weight_kld
        self.margin = args.margin
        self.num_stages = args.num_stages
        self.nranks = args.nranks

        # optimization hyper-parameters
        self.lr_vae = args.learning_rate_vae
        self.lr_dis = args.learning_rate_dis
        self.lr_gen = args.learning_rate_gen
        self.lr_decay = args.learning_rate_decay
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        self.optim_method = args.optim_method
        self.vae_loss_type = args.vae_loss_type

        # for classification
        # fixed latent sample reused across epochs so generated previews are
        # comparable over time
        self.fixed_noise = torch.FloatTensor(self.batch_size, self.nz).normal_(
            0, 1).to(self.device)  #, volatile=True)
        self.epsilon = torch.randn(self.batch_size, self.nz).to(self.device)
        # rank targets over nranks-1 thresholds: all-ones for real samples,
        # all-zeros for fakes (separate tensors for the D and G passes)
        self.target_real = torch.ones(self.batch_size,
                                      self.nranks - 1).to(self.device)
        self.target_fakeD = torch.zeros(self.batch_size,
                                        self.nranks - 1).to(self.device)
        self.target_fakeG = torch.zeros(self.batch_size,
                                        self.nranks - 1).to(self.device)
        self.sigmoid = torch.sigmoid

        # Initialize optimizer
        self.optimizerE = self.initialize_optimizer(self.Encoder,
                                                    lr=self.lr_vae,
                                                    optim_method='Adam')
        self.optimizerG = self.initialize_optimizer(self.modelG[0],
                                                    lr=self.lr_vae,
                                                    optim_method='Adam')
        # discriminator gets weight decay proportional to its learning rate
        self.optimizerD = self.initialize_optimizer(self.modelD[0],
                                                    lr=self.lr_dis,
                                                    optim_method='Adam',
                                                    weight_decay=0.01 *
                                                    self.lr_dis)
        # self.schedulerE = optim.lr_scheduler.ReduceLROnPlateau(self.optimizerE, factor=self.lr_decay, patience=self.scheduler_patience, min_lr=1e-3*self.lr_vae)
        # self.schedulerG = optim.lr_scheduler.ReduceLROnPlateau(self.optimizerG, factor=self.lr_decay, patience=self.scheduler_patience, min_lr=1e-3*self.lr_vae)
        # self.schedulerD = optim.lr_scheduler.ReduceLROnPlateau(self.optimizerD, factor=self.lr_decay, patience=self.scheduler_patience, min_lr=1e-3*self.lr_vae)
        # Automatic scheduler
        self.schedulerE = plugins.AutomaticLRScheduler(
            self.optimizerE,
            maxlen=self.scheduler_maxlen,
            factor=self.lr_decay,
            patience=self.scheduler_patience)
        self.schedulerG = plugins.AutomaticLRScheduler(
            self.optimizerG,
            maxlen=self.scheduler_maxlen,
            factor=self.lr_decay,
            patience=self.scheduler_patience)
        self.schedulerD = plugins.AutomaticLRScheduler(
            self.optimizerD,
            maxlen=self.scheduler_maxlen,
            factor=self.lr_decay,
            patience=self.scheduler_patience)

        # logging training
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = ['Loss_D0', 'Loss_G0', 'Acc_Real', 'Acc_Fake']
        self.log_loss_train.register(self.params_loss_train)

        # monitor training
        self.monitor_train = plugins.Monitor()
        self.params_monitor_train = [
            'Loss_D0', 'Loss_G0', 'Acc_Real', 'Acc_Fake'
        ]
        self.monitor_train.register(self.params_monitor_train)

        # Define visualizer plot type for given dataset
        if args.net_type == 'gmm':
            # GMM toy data: plot less often and show distributions instead of
            # image grids.  NOTE(review): overrides args.plot_update_interval;
            # gmm_dim outside {1, 2} would leave output_dtype unbound.
            self.plot_update_interval = 300
            if self.args.gmm_dim == 1:
                output_dtype, output_vtype = 'vector', 'histogram'
            elif self.args.gmm_dim == 2:
                output_dtype, output_vtype = 'vector', 'scatter'
        else:
            output_dtype, output_vtype = 'images', 'images'
            # image generators expect noise shaped (N, nz, 1, 1)
            self.fixed_noise = self.fixed_noise.unsqueeze(-1).unsqueeze(-1)

        # visualize training
        self.visualizer_train = plugins.Visualizer(port=self.port,
                                                   env=self.env,
                                                   title='Train')
        self.params_visualizer_train = {
            'Loss_D0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_gan',
                'layout': {
                    'windows': ['Loss_D0', 'Loss_G0'],
                    'id': 0
                }
            },
            'Loss_G0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_gan',
                'layout': {
                    'windows': ['Loss_D0', 'Loss_G0'],
                    'id': 1
                }
            },
            'Acc_Real': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'acc',
                'layout': {
                    'windows': ['Acc_Real', 'Acc_Fake'],
                    'id': 0
                }
            },
            'Acc_Fake': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'acc',
                'layout': {
                    'windows': ['Acc_Real', 'Acc_Fake'],
                    'id': 1
                }
            },
            'Learning_Rate_G': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr_G'
            },
            'Learning_Rate_D': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr_D'
            },
            'Real': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'real'
            },
            'Fake': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fake'
            },
            'Fakes_Current': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'Fakes_Current'
            },
            'Fakes_Previous': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fakes_prev'
            },
        }
        self.visualizer_train.register(self.params_visualizer_train)

        # display training progress
        self.print_train = '[%d/%d][%d/%d] '
        for item in self.params_loss_train:
            self.print_train = self.print_train + item + " %.3f "

        # D/G alternation counters: d_iter and g_iter are reset from the
        # *_init values as training proceeds
        self.giterations = 0
        self.d_iter_init = args.d_iter
        self.d_iter = self.d_iter_init
        self.g_iter_init = args.g_iter
        self.g_iter = self.g_iter_init
        print('Discriminator:', self.modelD[0])
        print('Generator:', self.modelG[0])
        print('Encoder:', self.Encoder)

        # define a zero tensor
        self.t_zero = torch.zeros(1)
        self.add_noise = args.add_noise
        self.noise_var = args.noise_var
Exemplo n.º 7
0
import sys
import os

from bottle import Bottle, route, run, static_file, url, request, get, redirect, default_app, response
from json import loads, dumps
from mixins import enable_cors
from api import api

import plugins

# `application` is the attribute name WSGI servers conventionally look for —
# presumably why the alias exists; verify against the deployment config.
app = application = Bottle()
# versioned API lives under its own sub-application
app.mount("/api/v1", api)
# app.mount("/api/v1", flames_app)

# request logging via the bottle plugin interface
api_logger = plugins.Logger(app_name='app')
app.install(api_logger)


@app.hook('after_request')
def after_hook():
    """Run after every request; delegates to mixins.enable_cors
    (presumably sets CORS response headers — verify in mixins)."""
    enable_cors()


@app.get("/test")
def home():
    return {"test": True}


@app.get("/")
def home():
Exemplo n.º 8
0
    def __init__(self, args, model, criterion):
        """Set up test-time buffers, logging, monitoring and visualization."""
        self.args = args
        self.model = model
        self.criterion = criterion

        self.port = args.port
        self.dir_save = args.save

        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        # preallocated batch buffers for classification
        self.label = torch.zeros(self.batch_size).long()
        self.input = torch.zeros(self.batch_size, self.nchannels,
                                 self.resolution_high, self.resolution_wide)

        if args.cuda:
            self.label = self.label.cuda()
            self.input = self.input.cuda()

        # inference-only wrappers (old-style autograd API)
        self.input = Variable(self.input, volatile=True)
        self.label = Variable(self.label, volatile=True)

        # logging testing
        self.log_loss_test = plugins.Logger(args.logs, 'TestLogger.txt')
        self.params_loss_test = ['Loss', 'Accuracy']
        self.log_loss_test.register(self.params_loss_test)

        # monitor testing
        self.monitor_test = plugins.Monitor()
        self.params_monitor_test = ['Loss', 'Accuracy']
        self.monitor_test.register(self.params_monitor_test)

        # visualize testing: scalar plots plus single/grid image panels
        self.visualizer_test = plugins.Visualizer(self.port, 'Test')
        scalar_plot = {'dtype': 'scalar', 'vtype': 'plot'}
        self.params_visualizer_test = {
            'Loss': dict(scalar_plot),
            'Accuracy': dict(scalar_plot),
            'Image': {'dtype': 'image', 'vtype': 'image'},
            'Images': {'dtype': 'images', 'vtype': 'images'},
        }
        self.visualizer_test.register(self.params_visualizer_test)

        # display testing progress: one "%.4f" slot per logged quantity
        self.print_test = 'Test [%d/%d][%d/%d] ' + ''.join(
            item + " %.4f " for item in self.params_loss_test)

        self.evalmodules = []
        self.losses_test = {}
Exemplo n.º 9
0
    def __init__(self, args, modelD, modelG, Encoder, criterion):

        self.args = args
        self.modelD = [modelD for i in range(2)]
        self.modelG = [modelG for i in range(2)]
        self.Encoder = Encoder
        self.criterion = criterion

        self.port = args.port
        self.env = args.env
        self.dir_save = args.save
        self.dataset_fraction = args.dataset_fraction
        self.len_dataset = 0

        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.stage_epochs = args.stage_epochs
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nz = args.nz
        self.gp = args.gp
        self.gp_lambda = args.gp_lambda
        self.scheduler_patience = args.scheduler_patience

        self.weight_gan_final = args.weight_gan_final
        self.weight_vae_init = args.weight_vae_init
        self.weight_kld = args.weight_kld
        self.margin = args.margin
        self.disc_diff_weight = args.disc_diff_weight
        self.num_stages = args.num_stages

        self.lr_vae = args.learning_rate_vae
        self.lr_dis = args.learning_rate_dis
        self.lr_gen = args.learning_rate_gen
        self.lr_decay = args.learning_rate_decay
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        self.optim_method = args.optim_method
        self.vae_loss_type = args.vae_loss_type

        # for classification
        self.input = Variable(torch.FloatTensor(self.batch_size,
                                                self.nchannels,
                                                self.resolution_high,
                                                self.resolution_wide),
                              requires_grad=True)
        self.test_input = Variable(torch.FloatTensor(self.batch_size,
                                                     self.nchannels,
                                                     self.resolution_high,
                                                     self.resolution_wide),
                                   volatile=True)
        self.fixed_noise = Variable(torch.FloatTensor(self.batch_size,
                                                      self.nz).normal_(0, 1),
                                    volatile=True)
        self.epsilon = Variable(torch.randn(self.batch_size, self.nz),
                                requires_grad=False)

        if args.cuda:
            self.input = self.input.cuda()
            self.test_input = self.test_input.cuda()
            self.fixed_noise = self.fixed_noise.cuda()
            self.epsilon = self.epsilon.cuda()

        # Initialize optimizer
        self.optimizerE = self.initialize_optimizer(self.Encoder,
                                                    lr=self.lr_vae,
                                                    optim_method='Adam')
        self.optimizerG = self.initialize_optimizer(self.modelG[0],
                                                    lr=self.lr_vae,
                                                    optim_method='Adam')
        self.optimizerD = self.initialize_optimizer(self.modelD[0],
                                                    lr=self.lr_dis,
                                                    optim_method='Adam',
                                                    weight_decay=100)
        self.schedulerE = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizerE, factor=self.lr_decay, min_lr=1e-3 * self.lr_vae)
        self.schedulerG = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizerG, factor=self.lr_decay, min_lr=1e-3 * self.lr_vae)
        self.schedulerD = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizerD, factor=self.lr_decay, min_lr=1e-3 * self.lr_vae)

        # logging training
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = [
            'Loss_D0', 'Loss_G0', 'MSE', 'KLD', 'Score_D0', 'Score_D1',
            'Score_D0_G0', 'Score_D0_G1', 'Score_D1_G1', 'Disc_Difference'
        ]
        self.log_loss_train.register(self.params_loss_train)

        self.log_loss_test = plugins.Logger(args.logs, 'TestLogger.txt')
        self.params_loss_test = [
            'SSIM', 'PSNR', 'Test_Score_D0', 'Test_Score_D0_G0',
            'Test_Score_D0_G1', 'Test_Disc_Accuracy'
        ]
        self.log_loss_test.register(self.params_loss_test)

        # monitor training
        self.monitor_train = plugins.Monitor()
        self.params_monitor_train = [
            'Loss_D0', 'Loss_G0', 'MSE', 'KLD', 'Score_D0', 'Score_D1',
            'Score_D0_G0', 'Score_D0_G1', 'Score_D1_G1', 'Disc_Difference'
        ]
        self.monitor_train.register(self.params_monitor_train)

        self.monitor_test = plugins.Monitor()
        self.params_monitor_test = [
            'SSIM', 'PSNR', 'Test_Score_D0', 'Test_Score_D0_G0',
            'Test_Score_D0_G1', 'Test_Disc_Accuracy'
        ]
        self.monitor_test.register(self.params_monitor_test)

        # Define visualizer plot type for given dataset
        if args.net_type == 'gmm':
            if self.args.gmm_dim == 1:
                output_dtype, output_vtype = 'vector', 'histogram'
            elif self.args.gmm_dim == 2:
                output_dtype, output_vtype = 'heatmap', 'heatmap'
        else:
            output_dtype, output_vtype = 'images', 'images'
            self.fixed_noise = self.fixed_noise.unsqueeze(-1).unsqueeze(-1)

        # visualize training
        self.visualizer_train = plugins.HourGlassVisualizer(port=self.port,
                                                            env=self.env,
                                                            title='Train')
        self.params_visualizer_train = {
            'Loss_D0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_gan',
                'layout': {
                    'windows': ['Loss_D0', 'Loss_G0'],
                    'id': 0
                }
            },
            'Loss_G0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_gan',
                'layout': {
                    'windows': ['Loss_D0', 'Loss_G0'],
                    'id': 1
                }
            },
            'MSE': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'enc_losses',
                'layout': {
                    'windows': ['MSE', 'KLD'],
                    'id': 0
                }
            },
            'KLD': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'enc_losses',
                'layout': {
                    'windows': ['MSE', 'KLD'],
                    'id': 1
                }
            },
            'Mean': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'norm_params',
                'layout': {
                    'windows': ['Mean', 'Sigma'],
                    'id': 0
                }
            },
            'Sigma': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'norm_params',
                'layout': {
                    'windows': ['Mean', 'Sigma'],
                    'id': 1
                }
            },
            'Score_D0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_D0',
                'layout': {
                    'windows': ['Score_D0', 'Score_D0_G0', 'Score_D0_G1'],
                    'id': 0
                }
            },
            'Score_D0_G0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_D0',
                'layout': {
                    'windows': ['Score_D0', 'Score_D0_G0', 'Score_D0_G1'],
                    'id': 1
                }
            },
            'Score_D0_G1': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_D0',
                'layout': {
                    'windows': ['Score_D0', 'Score_D0_G0', 'Score_D0_G1'],
                    'id': 2
                }
            },
            'Score_D1': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_D1',
                'layout': {
                    'windows': ['Score_D1', 'Score_D1_G1'],
                    'id': 0
                }
            },
            'Score_D1_G1': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_D1',
                'layout': {
                    'windows': ['Score_D1', 'Score_D1_G1'],
                    'id': 1
                }
            },
            'netD_norm_w': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'netd_norm'
            },
            'LR_E': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr',
                'layout': {
                    'windows': ['LR_E', 'LR_G', 'LR_D'],
                    'id': 0
                }
            },
            'LR_G': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr',
                'layout': {
                    'windows': ['LR_E', 'LR_G', 'LR_D'],
                    'id': 1
                }
            },
            'LR_D': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr',
                'layout': {
                    'windows': ['LR_E', 'LR_G', 'LR_D'],
                    'id': 2
                }
            },
            'Gradient_Penalty': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'gp'
            },
            'Disc_Difference': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'disc_diff'
            },
            'Disc_Accuracy': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'disc_acc'
            },
            'Real': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'real'
            },
            'Fakes_Encoder': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fakes_enc'
            },
            'Fakes_Normal': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fakes_normal'
            },
        }
        self.visualizer_train.register(self.params_visualizer_train)

        self.visualizer_test = plugins.HourGlassVisualizer(port=self.port,
                                                           env=self.env,
                                                           title='Test')
        self.params_visualizer_test = {
            'Test_Score_D0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'test_disc_scores',
                'layout': {
                    'windows':
                    ['Test_Score_D0', 'Test_Score_D0_G0', 'Test_Score_D0_G1'],
                    'id':
                    0
                }
            },
            'Test_Score_D0_G0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'test_disc_scores',
                'layout': {
                    'windows':
                    ['Test_Score_D0', 'Test_Score_D0_G0', 'Test_Score_D0_G1'],
                    'id':
                    1
                }
            },
            'Test_Score_D0_G1': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'test_disc_scores',
                'layout': {
                    'windows':
                    ['Test_Score_D0', 'Test_Score_D0_G0', 'Test_Score_D0_G1'],
                    'id':
                    2
                }
            },
            'SSIM': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'ssim_scores'
            },
            'PSNR': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'psnr_scores'
            },
            'Test_Real': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'real_test'
            },
            'Test_Fakes_Encoder': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fakes_enc_test'
            },
            'Test_Disc_Accuracy': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'test_disc_acc'
            },
        }
        self.visualizer_test.register(self.params_visualizer_test)

        # display training progress
        self.print_train = '[%d/%d][%d/%d] '
        for item in self.params_loss_train:
            self.print_train = self.print_train + item + " %.3f "

        self.print_test = '[%d/%d][%d/%d] '
        for item in self.params_loss_test:
            self.print_test = self.print_test + item + " %.3f "

        self.giterations = 0
        self.d_iter = args.d_iter
        self.g_iter = args.g_iter
        self.clamp_lower = args.clamp_lower
        self.clamp_upper = args.clamp_upper
        print('Discriminator:', self.modelD[0])
        print('Generator:', self.modelG[0])
        print('Encoder:', self.Encoder)

        # define a zero tensor
        self.t_zero = Variable(torch.zeros(1))
Exemplo n.º 10
0
        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()

        self.inputs = Variable(self.inputs, volatile=True)
        self.labels = Variable(self.labels, volatile=True)

        # logging testing
        self.log_loss = plugins.Logger(
            args.logs_dir,
            'TestLogger.txt',
            self.save_results
        )
        self.params_loss = ['Loss', 'Accuracy']
        self.log_loss.register(self.params_loss)

        # monitor testing
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'Loss': {'dtype': 'running_mean'},
            'Accuracy': {'dtype': 'running_mean'}
        }
        self.monitor.register(self.params_monitor)

        # visualize testing
Exemplo n.º 11
0
    def __init__(self, args, model, criterion, evaluation):
        """Set up the tester: buffers, logging, monitoring and visualization.

        Args:
            args: parsed command-line options (ports, sizes, paths, flags).
            model: network under evaluation.
            criterion: loss function used to score predictions.
            evaluation: metric callable used for accuracy computation.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.save_results = args.save_results

        # visdom / filesystem configuration
        self.env = args.env
        self.port = args.port
        self.dir_save = args.save_dir
        self.log_type = args.log_type

        # run configuration
        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        # preallocated classification buffers (filled per batch)
        self.labels = torch.zeros(self.batch_size).long()
        self.inputs = torch.zeros(self.batch_size, self.resolution_high,
                                  self.resolution_wide)
        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()
        # volatile: inference only, no autograd graph (legacy PyTorch API)
        self.inputs = Variable(self.inputs, volatile=True)
        self.labels = Variable(self.labels, volatile=True)

        # file logger for test metrics
        self.log_loss = plugins.Logger(args.logs_dir, 'TestLogger.txt',
                                       self.save_results)
        self.params_loss = ['Loss', 'Accuracy']
        self.log_loss.register(self.params_loss)

        # running-mean monitor over the same metrics
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'Loss': {'dtype': 'running_mean'},
            'Accuracy': {'dtype': 'running_mean'},
        }
        self.monitor.register(self.params_monitor)

        # visdom plots / image panes for the test phase
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Test')
        self.params_visualizer = {
            'Loss': {
                'dtype': 'scalar', 'vtype': 'plot', 'win': 'loss',
                'layout': {'windows': ['train', 'test'], 'id': 1},
            },
            'Accuracy': {
                'dtype': 'scalar', 'vtype': 'plot', 'win': 'accuracy',
                'layout': {'windows': ['train', 'test'], 'id': 1},
            },
            'Test_Image': {'dtype': 'image', 'vtype': 'image',
                           'win': 'test_image'},
            'Test_Images': {'dtype': 'images', 'vtype': 'images',
                            'win': 'test_images'},
        }
        self.visualizer.register(self.params_visualizer)

        # build the progress-message template once, per logging style
        if self.log_type == 'traditional':
            self.print_formatter = 'Test [%d/%d][%d/%d] '
            for metric in self.params_loss:
                self.print_formatter += metric + " %.4f "
        elif self.log_type == 'progressbar':
            self.print_formatter = ('({}/{})'
                                    ' Load: {:.6f}s'
                                    ' | Process: {:.3f}s'
                                    ' | Total: {:}'
                                    ' | ETA: {:}')
            for metric in self.params_loss:
                self.print_formatter += ' | ' + metric + ' {:.4f}'

        self.evalmodules = []
        self.losses = {}
Exemplo n.º 12
0
    def __init__(self, args, model, criterion):
        """Configure the trainer: optimizer, input buffers, logging, visdom.

        Args:
            args: parsed command-line options.
            model: network to be trained.
            criterion: training loss function.

        Raises:
            Exception: if ``args.optim_method`` is not one of
                Adam / RMSprop / SGD.
        """
        self.args = args
        self.model = model
        self.criterion = criterion

        self.port = args.port
        self.dir_save = args.save

        # run dimensions
        self.cuda = args.cuda
        self.ngpu = args.ngpu
        self.nepochs = args.nepochs
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        # optimizer hyper-parameters
        self.lr = args.learning_rate
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        self.optim_method = args.optim_method

        # select the optimizer by name
        if self.optim_method == 'Adam':
            self.optimizer = optim.Adam(
                model.parameters(), lr=self.lr,
                weight_decay=self.weight_decay,
                betas=(self.adam_beta1, self.adam_beta2))
        elif self.optim_method == 'RMSprop':
            self.optimizer = optim.RMSprop(
                model.parameters(), lr=self.lr,
                weight_decay=self.weight_decay)
        elif self.optim_method == 'SGD':
            self.optimizer = optim.SGD(
                model.parameters(), lr=self.lr,
                weight_decay=self.weight_decay,
                momentum=self.momentum, nesterov=True)
        else:
            raise Exception("Unknown Optimization Method")

        # preallocated classification buffers (filled per batch)
        self.label = torch.zeros(self.batch_size).long()
        self.input = torch.zeros(self.batch_size, self.nchannels,
                                 self.resolution_high, self.resolution_wide)
        if args.cuda:
            self.label = self.label.cuda()
            self.input = self.input.cuda()
        self.input = Variable(self.input)
        self.label = Variable(self.label)

        # file logger for training metrics
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = ['Loss', 'Accuracy']
        self.log_loss_train.register(self.params_loss_train)

        # running monitor over the same metrics
        self.monitor_train = plugins.Monitor()
        self.params_monitor_train = ['Loss', 'Accuracy']
        self.monitor_train.register(self.params_monitor_train)

        # visdom plots / image panes for the training phase
        self.visualizer_train = plugins.Visualizer(self.port, 'Train')
        self.params_visualizer_train = {
            'Loss': {'dtype': 'scalar', 'vtype': 'plot'},
            'Accuracy': {'dtype': 'scalar', 'vtype': 'plot'},
            'Image': {'dtype': 'image', 'vtype': 'image'},
            'Images': {'dtype': 'images', 'vtype': 'images'},
        }
        self.visualizer_train.register(self.params_visualizer_train)

        # build the progress-message template once
        self.print_train = 'Train [%d/%d][%d/%d] '
        for metric in self.params_loss_train:
            self.print_train = self.print_train + metric + " %.4f "

        self.evalmodules = []
        self.losses_train = {}
        print(self.model)
Exemplo n.º 13
0
    def __init__(self, args, model, criterion, evaluation):
        """Prepare the tester: buffers, logging, monitoring, visualization.

        Args:
            args: parsed command-line options.
            model: network under evaluation.
            criterion: loss function.
            evaluation: metric callable used for accuracy computation.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation

        self.nepochs = args.nepochs

        # preallocated classification buffers (filled per batch)
        self.labels = torch.zeros(args.batch_size).long()
        self.inputs = torch.zeros(args.batch_size, args.resolution_high,
                                  args.resolution_wide)
        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()
        self.inputs = Variable(self.inputs)
        self.labels = Variable(self.labels)

        # file logger for the accuracy metric
        self.log_loss = plugins.Logger(args.logs_dir, 'TestLogger.txt',
                                       args.save_results)
        metrics = ['ACC']
        self.log_loss.register(metrics)

        # running-mean monitor
        self.monitor = plugins.Monitor()
        self.params_monitor = {'ACC': {'dtype': 'running_mean'}}
        self.monitor.register(self.params_monitor)

        # visdom plot for accuracy (shared train/test window)
        self.visualizer = plugins.Visualizer(args.port, args.env, 'Test')
        params_visualizer = {
            'ACC': {
                'dtype': 'scalar', 'vtype': 'plot', 'win': 'acc',
                'layout': {'windows': ['train', 'test'], 'id': 1},
            },
        }
        self.visualizer.register(params_visualizer)

        # build the progress-message template once
        self.print_formatter = 'Test [%d/%d]] '
        for metric in ['ACC']:
            self.print_formatter += metric + " %.4f "

        self.losses = {}
        # bin centres used to map class predictions back to ages
        self.binage = torch.Tensor([19, 37.5, 52.5, 77])
Exemplo n.º 14
0
    def __init__(self, args, model, criterion):
        """Set up the generator trainer: optimizer, buffers, logging, visdom.

        Args:
            args: parsed command-line options.
            model: dict of networks; only ``model["netG"]`` is optimized.
            criterion: training loss function.
        """
        self.args = args
        self.model = model
        self.criterion = criterion

        self.port = args.port
        self.dir_save = args.save

        self.cuda = args.cuda
        # treat a missing epoch count as zero
        self.nepochs = 0 if args.nepochs is None else args.nepochs
        self.nchannels = args.ngchannels
        self.nechannels = args.nechannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nlatent = args.nlatent
        self.out_steps = args.out_steps

        # optimizer hyper-parameters
        self.learning_rate = args.learning_rate
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        # resolve the optimizer class by name, e.g. 'Adam' -> optim.Adam
        self.optim_method = getattr(optim, args.optim_method)

        self.optimizer = {
            "netG": self.optim_method(model["netG"].parameters(),
                                      lr=self.learning_rate),
        }

        # preallocated input buffers (filled per batch)
        self.composition = torch.FloatTensor(self.batch_size, self.nchannels,
                                             self.resolution_high,
                                             self.resolution_wide)
        self.metadata = torch.FloatTensor(self.batch_size, self.nechannels,
                                          self.resolution_high,
                                          self.resolution_wide)

        # round the output count up to a multiple of 8 when unspecified
        if args.out_images is None:
            self.out_images = self.batch_size + 8 - (self.batch_size % 8)
        else:
            self.out_images = args.out_images

        if args.cuda:
            self.composition = self.composition.cuda()
            self.metadata = self.metadata.cuda()
        self.composition = Variable(self.composition)
        self.metadata = Variable(self.metadata)

        # train/test loggers and monitors for the single L2 metric
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = ['L2']
        self.log_loss_train.register(self.params_loss_train)

        self.log_monitor_train = plugins.Monitor(smoothing=False)
        self.params_monitor_train = ['L2']
        self.log_monitor_train.register(self.params_monitor_train)

        self.log_loss_test = plugins.Logger(args.logs, 'TestLogger.txt')
        self.params_loss_test = ['L2']
        self.log_loss_test.register(self.params_loss_test)

        self.log_monitor_test = plugins.Monitor(smoothing=False)
        self.params_monitor_test = ['L2']
        self.log_monitor_test.register(self.params_monitor_test)

        # visdom plots for train and test phases
        self.visualizer_train = plugins.Visualizer(self.port, 'Train',
                                                   args.images, args.pre_name)
        self.params_visualizer_train = {
            'L2': {'dtype': 'scalar', 'vtype': 'plot'},
        }
        self.visualizer_train.register(self.params_visualizer_train)

        self.visualizer_test = plugins.Visualizer(self.port, 'Test',
                                                  args.images, args.pre_name)
        self.params_visualizer_test = {
            'L2': {'dtype': 'scalar', 'vtype': 'plot'},
        }
        self.visualizer_test.register(self.params_visualizer_test)

        self.imgio_test = plugins.ImageIO(args.images, args.pre_name)

        # progress-message templates (metric name padded to width 4)
        self.print_train = '[%04d/%04d][%02d/%02d] '
        for metric in self.params_loss_train:
            self.print_train = self.print_train + f"{metric:4}" + " %8.6f "

        self.print_test = '[%d/%d][%d/%d] '
        for metric in self.params_loss_test:
            self.print_test = self.print_test + f"{metric:4}" + " %8.6f "

        self.giterations = 1
        self.losses_train = {}
        self.losses_test = {}
        self.mu_sigma = {}
Exemplo n.º 15
0
    def __init__(self, args, netD, netG, netE):
        """Set up latent-space image completion with frozen D/G networks.

        Only the latent ``noise`` vector is optimized; the discriminator and
        generator weights are frozen. Also precomputes the occlusion mask and
        the Poisson matrices used for in-painting / blending.

        Args:
            args: parsed command-line options.
            netD: trained discriminator (frozen here).
            netG: trained generator (frozen here).
            netE: trained encoder.
        """
        self.netD = netD
        self.netG = netG
        self.netE = netE
        # freeze D and G: only the latent code is trainable
        for p in self.netD.parameters():
            p.requires_grad = False
        for p in self.netG.parameters():
            p.requires_grad = False
        self.cuda = args.cuda
        self.device = torch.device("cuda" if (
            self.cuda and torch.cuda.is_available()) else "cpu")

        self.args = args
        self.nchannels = args.nchannels
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nz = args.nz
        self.wcom = args.disc_loss_weight
        self.ssim_weight = args.ssim_weight
        self.cuda = args.cuda
        self.citers = args.citers
        self.lr = args.learning_rate_vae
        self.momentum = args.momentum
        self.batch_size = args.batch_size
        self.use_encoder = args.use_encoder
        self.start_index = args.start_index

        # latent code to optimize (the only tensor with requires_grad=True)
        self.noise = torch.FloatTensor(self.batch_size, self.nz, 1,
                                       1).normal_(0, 1).to(self.device)
        self.noise.requires_grad = True
        self.epsilon = torch.randn(self.batch_size, self.nz).to(self.device)

        self.optimizerC = optim.RMSprop([self.noise], lr=self.lr)
        self.scheduler = plugins.AutomaticLRScheduler(
            self.optimizerC,
            maxlen=500,
            factor=0.1,
            patience=self.args.scheduler_patience)
        self.log_eval_loss = plugins.Logger(args.logs, 'CompletionLog.txt')
        self.params_eval_loss = [
            'Image', 'Input_SSIM', 'Input_PSNR', 'SSIM', 'PSNR', 'C_Loss',
            'P_Loss'
        ]
        self.log_eval_loss.register(self.params_eval_loss)

        # Occlusion mask: pmask == 1 on observed pixels, 0 on the hole.
        self.pmask = torch.ones(self.resolution_high,
                                self.args.resolution_wide).to(self.device)
        if args.mask_type == 'central':
            self.l = int(self.resolution_high * self.args.scale)
            self.u = int(self.resolution_wide * (1 - self.args.scale))
            if self.l != self.u:
                self.pmask[5 + self.l:5 + self.u, self.l:self.u] = 0.0
        elif args.mask_type == 'periocular':
            # keep only the upper (eye) region visible
            self.pmask[int(0.4 * self.resolution_high):, :] = 0.0
            if self.args.scale == 0.3:
                self.pmask[:, :8] = 0.0
                self.pmask[:, 56:] = 0.0
        self.nmask = torch.add(-self.pmask, 1)  # complement: 1 on the hole
        self.non_mask_pixels = (self.nmask.view(-1) == 0).nonzero().squeeze()

        # Coefficient matrix of the discrete Poisson equation: identity rows
        # on observed pixels, 4-neighbour Laplacian rows on hole pixels.
        # Pixels are flattened row-major: index = x + y * resolution_wide.
        self.num_pixels = self.resolution_high * self.resolution_wide
        A = scipy.sparse.identity(self.num_pixels, format='lil')
        for y in range(self.resolution_high):
            for x in range(self.resolution_wide):
                if self.nmask[y, x]:
                    index = x + y * self.resolution_wide
                    A[index, index] = 4
                    # NOTE(review): the +/-1 neighbours wrap across row
                    # boundaries at the image edge; fine only while the hole
                    # never touches the left/right border - confirm.
                    if index + 1 < self.num_pixels:
                        A[index, index + 1] = -1
                    if index - 1 >= 0:
                        A[index, index - 1] = -1
                    if index + self.resolution_wide < self.num_pixels:
                        A[index, index + self.resolution_wide] = -1
                    if index - self.resolution_wide >= 0:
                        A[index, index - self.resolution_wide] = -1
        A = torch.Tensor(A.toarray())
        self.Ainv = torch.inverse(A).to(self.device)

        # Dense Poisson (Laplacian) matrix used for blending. BUG FIX: with
        # row-major flattening (index = x + y * resolution_wide, as in A
        # above) the vertical-neighbour offset is resolution_wide, not
        # resolution_high; the original code used resolution_high, which was
        # only correct for square images.
        P = 4 * torch.eye(self.num_pixels)
        diag = np.arange(0, self.num_pixels - 1)
        P[diag, diag + 1] = -1
        P[diag + 1, diag] = -1
        diag = np.arange(self.resolution_wide, self.num_pixels)
        P[diag - self.resolution_wide, diag] = -1
        P[diag, diag - self.resolution_wide] = -1
        self.P = P.to(self.device)

        self.criterion = nn.L1Loss().to(self.device)
        self.ssim_loss = pytorch_ssim.SSIM()

        self.blend = args.blend

        self.losses = {}

        self.log_eval_monitor = plugins.Monitor()
        self.params_eval_monitor = [
            'Image', 'Input_SSIM', 'Input_PSNR', 'SSIM', 'PSNR', 'C_Loss',
            'P_Loss'
        ]
        self.log_eval_monitor.register(self.params_eval_monitor)

        # progress-message template ('Image' is an integer index)
        self.print = '[%d/%d] [%d/%d]'
        for item in self.params_eval_loss:
            if item == 'Image':
                self.print = self.print + item + " %d "
            else:
                self.print = self.print + item + " %.4f "

        self.c_loss = 0
        self.d_loss = 0
        self.disc_type = args.disc_type
Exemplo n.º 16
0
    def __init__(self, args, model, criterion, evaluation):
        """Configure the trainer: one optimizer/scheduler over all
        sub-modules, input buffers, logging, monitoring and visualization.

        Args:
            args: parsed command-line options.
            model: dict of sub-modules trained jointly.
            criterion: training loss function.
            evaluation: metric callable.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.save_results = args.save_results

        # visdom / filesystem configuration
        self.env = args.env
        self.port = args.port
        self.dir_save = args.save_dir
        self.log_type = args.log_type

        # run configuration
        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        # optimization configuration
        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method
        self.scheduler_options = args.scheduler_options
        self.weight_loss = args.weight_loss

        # collect every sub-module so one optimizer covers all parameters
        self.module_list = nn.ModuleList()
        for key in list(self.model):
            self.module_list.append(self.model[key])
        optimizer_cls = getattr(optim, self.optim_method)
        self.optimizer = optimizer_cls(self.module_list.parameters(),
                                       lr=self.lr, **self.optim_options)
        if self.scheduler_method is not None:
            scheduler_cls = getattr(optim.lr_scheduler, self.scheduler_method)
            self.scheduler = scheduler_cls(self.optimizer,
                                           **self.scheduler_options)

        # preallocated classification buffers (filled per batch)
        self.labels = torch.zeros(self.batch_size).long()
        self.inputs = torch.zeros(self.batch_size, self.resolution_high,
                                  self.resolution_wide)
        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()
        self.inputs = Variable(self.inputs)
        self.labels = Variable(self.labels)

        # file logger for training metrics
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       self.save_results)
        self.params_loss = ['LearningRate', 'Loss']
        self.log_loss.register(self.params_loss)

        # running-mean monitor over the same metrics
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'LearningRate': {'dtype': 'running_mean'},
            'Loss': {'dtype': 'running_mean'},
        }
        self.monitor.register(self.params_monitor)

        # visdom plots (shared train/test windows)
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Train')
        self.params_visualizer = {
            'LearningRate': {
                'dtype': 'scalar', 'vtype': 'plot', 'win': 'learning_rate',
                'layout': {'windows': ['train', 'test'], 'id': 0},
            },
            'Loss': {
                'dtype': 'scalar', 'vtype': 'plot', 'win': 'loss',
                'layout': {'windows': ['train', 'test'], 'id': 0},
            },
        }
        self.visualizer.register(self.params_visualizer)

        # build the progress-message template once
        self.print_formatter = 'Train [%d/%d][%d/%d] '
        for metric in self.params_loss:
            self.print_formatter += metric + " %.4f "

        self.evalmodules = []
        self.losses = {}