Example #1
0
    def __init__(self, args, model, criterion, evaluation):
        """Set up training: optimizer, scheduler, pre-allocated batch buffers,
        logging, monitoring and visualization.

        Args:
            args: configuration namespace (paths, device, hyper-parameters).
            model: network whose trainable parameters are optimized.
            criterion: training loss function.
            evaluation: metric callable used to score predictions.

        Raises:
            ValueError: if ``args.log_type`` is not a recognized mode.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.save_results = args.save_results

        self.env = args.env
        self.port = args.port
        self.dir_save = args.save_dir
        self.log_type = args.log_type

        self.device = args.device
        self.nepochs = args.nepochs
        self.batch_size = args.batch_size

        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method
        self.scheduler_options = args.scheduler_options

        # Optimize only parameters that require gradients (frozen layers are
        # skipped by the filter).
        self.optimizer = getattr(optim,
                                 self.optim_method)(filter(
                                     lambda p: p.requires_grad,
                                     model.parameters()),
                                                    lr=self.lr,
                                                    **self.optim_options)
        # NOTE: self.scheduler exists only when a scheduler_method is given.
        if self.scheduler_method is not None:
            self.scheduler = getattr(optim.lr_scheduler,
                                     self.scheduler_method)(
                                         self.optimizer,
                                         **self.scheduler_options)

        # Pre-allocated buffers for classification batches.
        # NOTE(review): inputs has no channel dimension (batch, H, W) —
        # confirm against the data loader.
        self.labels = torch.zeros(self.batch_size,
                                  dtype=torch.long,
                                  device=self.device)
        self.inputs = torch.zeros(self.batch_size,
                                  self.resolution_high,
                                  self.resolution_wide,
                                  device=self.device)

        # logging training
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       self.save_results)
        self.params_loss = ['Loss', 'Accuracy']
        self.log_loss.register(self.params_loss)

        # monitor training
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'Loss': {
                'dtype': 'running_mean'
            },
            'Accuracy': {
                'dtype': 'running_mean'
            }
        }
        self.monitor.register(self.params_monitor)

        # visualize training
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Train')
        self.params_visualizer = {
            'Loss': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Accuracy': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'accuracy',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Train_Image': {
                'dtype': 'image',
                'vtype': 'image',
                'win': 'train_image'
            },
            'Train_Images': {
                'dtype': 'images',
                'vtype': 'images',
                'win': 'train_images'
            },
        }
        self.visualizer.register(self.params_visualizer)

        if self.log_type == 'traditional':
            # display training progress
            self.print_formatter = 'Train [%d/%d][%d/%d] '
            for item in self.params_loss:
                self.print_formatter += item + " %.4f "
        elif self.log_type == 'progressbar':
            # progress bar message formatter
            self.print_formatter = '({}/{})' \
                                   ' Load: {:.6f}s' \
                                   ' | Process: {:.3f}s' \
                                   ' | Total: {:}' \
                                   ' | ETA: {:}'
            for item in self.params_loss:
                self.print_formatter += ' | ' + item + ' {:.4f}'
            self.print_formatter += ' | lr: {:.2e}'
        else:
            # Fail fast: print_formatter would otherwise be left undefined and
            # training would crash later with an opaque AttributeError.
            raise ValueError('Unknown log_type: %s' % self.log_type)

        self.evalmodules = []
        self.losses = {}
    def __init__(self, args, modelD, modelG, Encoder, criterion, prevD, prevG):
        """Set up stage-wise GAN + VAE training: models, rank-order loss,
        optimizers, LR schedulers, logging, monitoring and visualization.

        Args:
            args: configuration namespace.
            modelD: discriminator network.
            modelG: generator network.
            Encoder: VAE encoder network.
            criterion: loss function.
            prevD: discriminator from the previous training stage.
            prevG: generator from the previous training stage.

        Raises:
            ValueError: if ``args.net_type == 'gmm'`` with an unsupported
                ``gmm_dim``.
        """
        self.args = args
        # NOTE(review): both list entries alias the SAME module instance —
        # confirm that independent copies are not required.
        self.modelD = [modelD for i in range(2)]
        self.modelG = [modelG for i in range(2)]
        self.Encoder = Encoder
        self.prevD = prevD
        self.prevG = prevG
        self.criterion = criterion
        self.cuda = args.cuda
        self.device = torch.device("cuda" if (
            self.cuda and torch.cuda.is_available()) else "cpu")
        self.logits_loss = RankOrderLoss(self.device)
        self.logits_eval = Logits_Classification(threshold=0.5)
        self.plot_update_interval = args.plot_update_interval

        self.port = args.port
        self.env = args.env
        self.result_path = args.result_path
        self.save_path = args.save
        self.log_path = args.logs
        self.dataset_fraction = args.dataset_fraction
        self.len_dataset = 0

        self.stage_epochs = args.stage_epochs
        self.start_stage = args.start_stage
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nz = args.nz
        self.gp = args.gp
        self.gp_lambda = args.gp_lambda
        self.scheduler_patience = args.scheduler_patience
        self.scheduler_maxlen = args.scheduler_maxlen

        self.weight_gan_final = args.weight_gan_final
        self.weight_vae_init = args.weight_vae_init
        self.weight_kld = args.weight_kld
        self.margin = args.margin
        self.num_stages = args.num_stages
        self.nranks = args.nranks

        self.lr_vae = args.learning_rate_vae
        self.lr_dis = args.learning_rate_dis
        self.lr_gen = args.learning_rate_gen
        self.lr_decay = args.learning_rate_decay
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        self.optim_method = args.optim_method
        self.vae_loss_type = args.vae_loss_type

        # Fixed latent noise for reproducible sample grids; epsilon for the
        # VAE reparameterization trick.
        self.fixed_noise = torch.FloatTensor(self.batch_size, self.nz).normal_(
            0, 1).to(self.device)
        self.epsilon = torch.randn(self.batch_size, self.nz).to(self.device)
        # Rank-order targets: nranks-1 binary targets per sample.
        self.target_real = torch.ones(self.batch_size,
                                      self.nranks - 1).to(self.device)
        self.target_fakeD = torch.zeros(self.batch_size,
                                        self.nranks - 1).to(self.device)
        self.target_fakeG = torch.zeros(self.batch_size,
                                        self.nranks - 1).to(self.device)
        self.sigmoid = torch.sigmoid

        # Initialize optimizers: Adam for encoder/generator (VAE learning
        # rate) and for the discriminator (its own LR plus a small decay).
        self.optimizerE = self.initialize_optimizer(self.Encoder,
                                                    lr=self.lr_vae,
                                                    optim_method='Adam')
        self.optimizerG = self.initialize_optimizer(self.modelG[0],
                                                    lr=self.lr_vae,
                                                    optim_method='Adam')
        self.optimizerD = self.initialize_optimizer(self.modelD[0],
                                                    lr=self.lr_dis,
                                                    optim_method='Adam',
                                                    weight_decay=0.01 *
                                                    self.lr_dis)
        # Automatic scheduler
        self.schedulerE = plugins.AutomaticLRScheduler(
            self.optimizerE,
            maxlen=self.scheduler_maxlen,
            factor=self.lr_decay,
            patience=self.scheduler_patience)
        self.schedulerG = plugins.AutomaticLRScheduler(
            self.optimizerG,
            maxlen=self.scheduler_maxlen,
            factor=self.lr_decay,
            patience=self.scheduler_patience)
        self.schedulerD = plugins.AutomaticLRScheduler(
            self.optimizerD,
            maxlen=self.scheduler_maxlen,
            factor=self.lr_decay,
            patience=self.scheduler_patience)

        # logging training
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = ['Loss_D0', 'Loss_G0', 'Acc_Real', 'Acc_Fake']
        self.log_loss_train.register(self.params_loss_train)

        # monitor training
        self.monitor_train = plugins.Monitor()
        self.params_monitor_train = [
            'Loss_D0', 'Loss_G0', 'Acc_Real', 'Acc_Fake'
        ]
        self.monitor_train.register(self.params_monitor_train)

        # Define visualizer plot type for given dataset
        if args.net_type == 'gmm':
            self.plot_update_interval = 300
            if self.args.gmm_dim == 1:
                output_dtype, output_vtype = 'vector', 'histogram'
            elif self.args.gmm_dim == 2:
                output_dtype, output_vtype = 'vector', 'scatter'
            else:
                # Fail fast: output_dtype/output_vtype would otherwise be
                # unbound and the dict below would raise a confusing NameError.
                raise ValueError('Unsupported gmm_dim: %s' % self.args.gmm_dim)
        else:
            output_dtype, output_vtype = 'images', 'images'
            # Image generators expect latent vectors shaped (N, nz, 1, 1).
            self.fixed_noise = self.fixed_noise.unsqueeze(-1).unsqueeze(-1)

        # visualize training
        self.visualizer_train = plugins.Visualizer(port=self.port,
                                                   env=self.env,
                                                   title='Train')
        self.params_visualizer_train = {
            'Loss_D0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_gan',
                'layout': {
                    'windows': ['Loss_D0', 'Loss_G0'],
                    'id': 0
                }
            },
            'Loss_G0': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_gan',
                'layout': {
                    'windows': ['Loss_D0', 'Loss_G0'],
                    'id': 1
                }
            },
            'Acc_Real': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'acc',
                'layout': {
                    'windows': ['Acc_Real', 'Acc_Fake'],
                    'id': 0
                }
            },
            'Acc_Fake': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'acc',
                'layout': {
                    'windows': ['Acc_Real', 'Acc_Fake'],
                    'id': 1
                }
            },
            'Learning_Rate_G': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr_G'
            },
            'Learning_Rate_D': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'lr_D'
            },
            'Real': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'real'
            },
            'Fake': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fake'
            },
            'Fakes_Current': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'Fakes_Current'
            },
            'Fakes_Previous': {
                'dtype': output_dtype,
                'vtype': output_vtype,
                'win': 'fakes_prev'
            },
        }
        self.visualizer_train.register(self.params_visualizer_train)

        # display training progress
        self.print_train = '[%d/%d][%d/%d] '
        for item in self.params_loss_train:
            self.print_train = self.print_train + item + " %.3f "

        # Generator/discriminator alternation counters.
        self.giterations = 0
        self.d_iter_init = args.d_iter
        self.d_iter = self.d_iter_init
        self.g_iter_init = args.g_iter
        self.g_iter = self.g_iter_init
        print('Discriminator:', self.modelD[0])
        print('Generator:', self.modelG[0])
        print('Encoder:', self.Encoder)

        # define a zero tensor
        self.t_zero = torch.zeros(1)
        self.add_noise = args.add_noise
        self.noise_var = args.noise_var
Example #3
0
    def __init__(self, args, model, criterion, evaluation):
        """Set up multi-task (demographics + identity) training: two
        optimizers, optional scheduler, batch buffers, logging, monitoring
        and visualization.

        Args:
            args: configuration namespace.
            model: dict of networks; 'feat' (feature CNN) and 'discrim' are
                used here.
            criterion: dict of losses keyed 'id', 'gender', 'age', 'race'.
            evaluation: metric callable used to score predictions.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation

        self.nepochs = args.nepochs

        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method

        # Two optimizers: one for the feature CNN, one for the classifier
        # heads plus the discriminator (collected into a single ModuleList).
        self.optimizer_cnn = getattr(optim, self.optim_method)(
            model['feat'].parameters(), lr=self.lr, **self.optim_options)

        module_list = nn.ModuleList([
            criterion['id'], criterion['gender'], criterion['age'],
            criterion['race'], model['discrim']
        ])
        self.optimizer_cls = getattr(optim, self.optim_method)(
            module_list.parameters(), lr=self.lr, **self.optim_options)

        # NOTE(review): 'Customer' looks like a typo for 'Custom', but it is
        # compared against a runtime config value — confirm against the
        # option parser before changing it. self.scheduler exists only when
        # this branch is taken.
        if self.scheduler_method is not None:
            if self.scheduler_method != 'Customer':
                self.scheduler = getattr(optim.lr_scheduler,
                                         self.scheduler_method)(
                                             self.optimizer_cnn,
                                             **args.scheduler_options)

        # Pre-allocated classification buffers (labels + image batch).
        # NOTE(review): inputs has no channel dimension — confirm against
        # the data loader.
        self.labels = torch.zeros(args.batch_size).long()
        self.inputs = torch.zeros(args.batch_size, args.resolution_high,
                                  args.resolution_wide)

        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()

        # Legacy (pre-0.4) PyTorch autograd API: wrap once, refill per batch.
        self.inputs = Variable(self.inputs)
        self.labels = Variable(self.labels)

        # logging training
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       args.save_results)
        params_loss = ['LearningRate','Loss_cls_demog', 'Loss_cls_id',\
            'Loss_conf_demog', 'Loss_conf_id', 'Loss_cls_mi', 'Loss_conf_mi']
        self.log_loss.register(params_loss)

        # monitor training
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'LearningRate': {
                'dtype': 'running_mean'
            },
            'Loss_cls_demog': {
                'dtype': 'running_mean'
            },
            'Loss_cls_id': {
                'dtype': 'running_mean'
            },
            'Loss_cls_mi': {
                'dtype': 'running_mean'
            },
            'Loss_conf_demog': {
                'dtype': 'running_mean'
            },
            'Loss_conf_id': {
                'dtype': 'running_mean'
            },
            'Loss_conf_mi': {
                'dtype': 'running_mean'
            },
        }
        self.monitor.register(self.params_monitor)

        # visualize training
        self.visualizer = plugins.Visualizer(args.port, args.env, 'Train')
        params_visualizer = {
            'LearningRate': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'learning_rate',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss_cls_demog': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_cls',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss_cls_id': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_cls',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss_cls_mi': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_cls',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss_conf_demog': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_conf',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss_conf_id': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_conf',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss_conf_mi': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss_conf',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Train_Image': {
                'dtype': 'image',
                'vtype': 'image',
                'win': 'train_image'
            },
            'Train_Images': {
                'dtype': 'images',
                'vtype': 'images',
                'win': 'train_images'
            },
        }
        self.visualizer.register(params_visualizer)

        # display training progress
        self.print_formatter = 'Train [%d/%d][%d/%d] '
        for item in params_loss:
            self.print_formatter += item + " %.4f "
        # self.print_formatter += "Scale %.4f "

        self.losses = {}
        # Age-bin centers used to map age-class predictions back to years.
        # NOTE(review): assumes a 9-bin age grouping — confirm against the
        # dataset's label binning.
        self.binage = torch.Tensor(
            [10, 22.5, 27.5, 32.5, 37.5, 42.5, 47.5, 55, 75])
Example #4
0
<<<<<<< HEAD
        # self.visualizer_test = plugins.Visualizer(self.port, 'Test')
        # self.params_visualizer_test = {
        #     'Loss': {'dtype': 'scalar', 'vtype': 'plot'},
        #     'Accuracy': {'dtype': 'scalar', 'vtype': 'plot'},
        #     'Image': {'dtype': 'image', 'vtype': 'image'},
        #     'Images': {'dtype': 'images', 'vtype': 'images'},
        # }
        # self.visualizer_test.register(self.params_visualizer_test)

        # display testing progress
        self.print_test = 'Test [%d/%d][%d/%d] '
        for item in self.params_loss_test:
            self.print_test = self.print_test + item + " %.4f "
=======
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Test')
        self.params_visualizer = {
            'Loss': {'dtype': 'scalar', 'vtype': 'plot', 'win': 'loss',
                     'layout': {'windows': ['train', 'test'], 'id': 1}},
            'Accuracy': {'dtype': 'scalar', 'vtype': 'plot', 'win': 'accuracy',
                         'layout': {'windows': ['train', 'test'], 'id': 1}},
            'Test_Image': {'dtype': 'image', 'vtype': 'image',
                           'win': 'test_image'},
            'Test_Images': {'dtype': 'images', 'vtype': 'images',
                            'win': 'test_images'},
        }
        self.visualizer.register(self.params_visualizer)

        if self.log_type == 'traditional':
            # display training progress
            self.print_formatter = 'Test [%d/%d][%d/%d] '
Example #5
0
    def __init__(self, args, model, criterion):
        """Prepare test-time state: pre-allocated batch buffers, logging,
        monitoring and visualization.

        Args:
            args: configuration namespace.
            model: network under evaluation.
            criterion: loss function used for reporting.
        """
        self.args = args
        self.model = model
        self.criterion = criterion

        self.port = args.port
        self.dir_save = args.save

        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        # Pre-allocated classification buffers (labels + image batch).
        self.label = torch.zeros(self.batch_size).long()
        self.input = torch.zeros(self.batch_size, self.nchannels,
                                 self.resolution_high, self.resolution_wide)

        if args.cuda:
            self.label = self.label.cuda()
            self.input = self.input.cuda()

        # Legacy autograd API; volatile=True marks inference-only graphs.
        self.input = Variable(self.input, volatile=True)
        self.label = Variable(self.label, volatile=True)

        # logging testing
        self.log_loss_test = plugins.Logger(args.logs, 'TestLogger.txt')
        self.params_loss_test = ['Loss', 'Accuracy']
        self.log_loss_test.register(self.params_loss_test)

        # monitor testing
        self.monitor_test = plugins.Monitor()
        self.params_monitor_test = ['Loss', 'Accuracy']
        self.monitor_test.register(self.params_monitor_test)

        # visualize testing
        self.visualizer_test = plugins.Visualizer(self.port, 'Test')
        self.params_visualizer_test = {
            'Loss': {'dtype': 'scalar', 'vtype': 'plot'},
            'Accuracy': {'dtype': 'scalar', 'vtype': 'plot'},
            'Image': {'dtype': 'image', 'vtype': 'image'},
            'Images': {'dtype': 'images', 'vtype': 'images'},
        }
        self.visualizer_test.register(self.params_visualizer_test)

        # Progress-line template: "Test [e/E][b/B] Loss x Accuracy y".
        pieces = ['Test [%d/%d][%d/%d] ']
        pieces.extend(name + ' %.4f ' for name in self.params_loss_test)
        self.print_test = ''.join(pieces)

        self.evalmodules = []
        self.losses_test = {}
Example #6
0
    def __init__(self, args, model, criterion, evaluation):
        """Set up testing: pre-allocated batch buffers, logging, monitoring
        and visualization.

        Args:
            args: configuration namespace.
            model: network under evaluation.
            criterion: loss function used for reporting.
            evaluation: metric callable used to score predictions.

        Raises:
            ValueError: if ``args.log_type`` is not a recognized mode.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.save_results = args.save_results

        self.env = args.env
        self.port = args.port
        self.dir_save = args.save_dir
        self.log_type = args.log_type

        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.batch_size = args.batch_size

        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        # Pre-allocated classification buffers.
        # NOTE(review): inputs has no channel dimension (batch, H, W) —
        # confirm against the data loader.
        self.labels = torch.zeros(self.batch_size).long()
        self.inputs = torch.zeros(self.batch_size, self.resolution_high,
                                  self.resolution_wide)

        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()

        # Legacy autograd API; volatile=True marks inference-only graphs.
        self.inputs = Variable(self.inputs, volatile=True)
        self.labels = Variable(self.labels, volatile=True)

        # logging testing
        self.log_loss = plugins.Logger(args.logs_dir, 'TestLogger.txt',
                                       self.save_results)
        self.params_loss = ['Loss', 'Accuracy']
        self.log_loss.register(self.params_loss)

        # monitor testing
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'Loss': {
                'dtype': 'running_mean'
            },
            'Accuracy': {
                'dtype': 'running_mean'
            }
        }
        self.monitor.register(self.params_monitor)

        # visualize testing
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Test')
        self.params_visualizer = {
            'Loss': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 1
                }
            },
            'Accuracy': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'accuracy',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 1
                }
            },
            'Test_Image': {
                'dtype': 'image',
                'vtype': 'image',
                'win': 'test_image'
            },
            'Test_Images': {
                'dtype': 'images',
                'vtype': 'images',
                'win': 'test_images'
            },
        }
        self.visualizer.register(self.params_visualizer)

        if self.log_type == 'traditional':
            # display training progress
            self.print_formatter = 'Test [%d/%d][%d/%d] '
            for item in self.params_loss:
                self.print_formatter += item + " %.4f "
        elif self.log_type == 'progressbar':
            # progress bar message formatter
            self.print_formatter = '({}/{})' \
                                   ' Load: {:.6f}s' \
                                   ' | Process: {:.3f}s' \
                                   ' | Total: {:}' \
                                   ' | ETA: {:}'
            for item in self.params_loss:
                self.print_formatter += ' | ' + item + ' {:.4f}'
        else:
            # Fail fast: print_formatter would otherwise be left undefined and
            # testing would crash later with an opaque AttributeError.
            raise ValueError('Unknown log_type: %s' % self.log_type)

        self.evalmodules = []
        self.losses = {}
Example #7
0
    def __init__(self, args, model, criterion):
        """Build training state for a classifier: optimizer selection by
        name, pre-allocated batch buffers, logging, monitoring and
        visualization.

        Args:
            args: configuration namespace.
            model: network whose parameters are optimized.
            criterion: training loss function.

        Raises:
            Exception: if ``args.optim_method`` is not Adam/RMSprop/SGD.
        """
        self.args = args
        self.model = model
        self.criterion = criterion

        self.port = args.port
        self.dir_save = args.save

        self.cuda = args.cuda
        self.ngpu = args.ngpu
        self.nepochs = args.nepochs
        self.nchannels = args.nchannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        self.lr = args.learning_rate
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        self.optim_method = args.optim_method

        # Select the optimizer by name; anything unrecognized is an error.
        if self.optim_method == 'Adam':
            self.optimizer = optim.Adam(model.parameters(), lr=self.lr,
                                        weight_decay=self.weight_decay,
                                        betas=(self.adam_beta1,
                                               self.adam_beta2))
        elif self.optim_method == 'RMSprop':
            self.optimizer = optim.RMSprop(model.parameters(), lr=self.lr,
                                           weight_decay=self.weight_decay)
        elif self.optim_method == 'SGD':
            self.optimizer = optim.SGD(model.parameters(), lr=self.lr,
                                       weight_decay=self.weight_decay,
                                       momentum=self.momentum,
                                       nesterov=True)
        else:
            raise Exception("Unknown Optimization Method")

        # Pre-allocated classification buffers (labels + image batch).
        self.label = torch.zeros(self.batch_size).long()
        self.input = torch.zeros(self.batch_size, self.nchannels,
                                 self.resolution_high, self.resolution_wide)

        if args.cuda:
            self.label = self.label.cuda()
            self.input = self.input.cuda()

        # Legacy autograd API: wrap buffers once, refill them per batch.
        self.input = Variable(self.input)
        self.label = Variable(self.label)

        # logging training
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = ['Loss', 'Accuracy']
        self.log_loss_train.register(self.params_loss_train)

        # monitor training
        self.monitor_train = plugins.Monitor()
        self.params_monitor_train = ['Loss', 'Accuracy']
        self.monitor_train.register(self.params_monitor_train)

        # visualize training
        self.visualizer_train = plugins.Visualizer(self.port, 'Train')
        self.params_visualizer_train = {
            'Loss': {'dtype': 'scalar', 'vtype': 'plot'},
            'Accuracy': {'dtype': 'scalar', 'vtype': 'plot'},
            'Image': {'dtype': 'image', 'vtype': 'image'},
            'Images': {'dtype': 'images', 'vtype': 'images'},
        }
        self.visualizer_train.register(self.params_visualizer_train)

        # Progress-line template: "Train [e/E][b/B] Loss x Accuracy y".
        pieces = ['Train [%d/%d][%d/%d] ']
        pieces.extend(name + ' %.4f ' for name in self.params_loss_train)
        self.print_train = ''.join(pieces)

        self.evalmodules = []
        self.losses_train = {}
        print(self.model)
Example #8
0
    def __init__(self, args, model, criterion, evaluation):
        """Set up age-classification testing: batch buffers, accuracy
        logging, monitoring and visualization.

        Args:
            args: configuration namespace.
            model: network under evaluation.
            criterion: loss function (kept for interface parity).
            evaluation: metric callable used to score predictions.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation

        self.nepochs = args.nepochs

        # Pre-allocated classification buffers.
        # NOTE(review): inputs has no channel dimension (batch, H, W) —
        # confirm against the data loader.
        self.labels = torch.zeros(args.batch_size).long()
        self.inputs = torch.zeros(args.batch_size, args.resolution_high,
                                  args.resolution_wide)

        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()

        # Legacy autograd API: wrap buffers once, refill them per batch.
        self.inputs = Variable(self.inputs)
        self.labels = Variable(self.labels)

        # logging testing
        self.log_loss = plugins.Logger(args.logs_dir, 'TestLogger.txt',
                                       args.save_results)
        params_loss = ['ACC']
        self.log_loss.register(params_loss)

        # monitor testing
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'ACC': {
                'dtype': 'running_mean'
            },
        }
        self.monitor.register(self.params_monitor)

        # visualize testing
        self.visualizer = plugins.Visualizer(args.port, args.env, 'Test')
        params_visualizer = {
            'ACC': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'acc',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 1
                }
            },
        }
        self.visualizer.register(params_visualizer)

        # display training progress
        # (fixed doubled ']' typo in the original 'Test [%d/%d]] ' template)
        self.print_formatter = 'Test [%d/%d] '
        for item in ['ACC']:
            self.print_formatter += item + " %.4f "

        self.losses = {}
        # Age-bin centers used to map age-class predictions back to years.
        # NOTE(review): assumes a 4-bin age grouping — confirm against the
        # dataset's label binning.
        self.binage = torch.Tensor([19, 37.5, 52.5, 77])
Example #9
0
    def __init__(self, args, model, criterion):
        """Set up the generator trainer.

        Builds the optimizer for the generator module, preallocates the
        input buffers, and wires up loggers, monitors, visualizers and
        progress-format strings for both training and testing.

        Args:
            args: parsed command-line namespace (hyperparameters, paths,
                logging/visualization settings).
            model: mapping of module names to networks; only ``model["netG"]``
                is optimized by this trainer.
            criterion: loss module(s) used during training/testing.
        """

        self.args = args
        self.model = model
        self.criterion = criterion

        self.port = args.port
        self.dir_save = args.save

        self.cuda = args.cuda
        # A missing epoch count means "do not train" rather than crashing later.
        self.nepochs = args.nepochs if args.nepochs is not None else 0
        self.nchannels = args.ngchannels
        self.nechannels = args.nechannels
        self.batch_size = args.batch_size
        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide
        self.nlatent = args.nlatent
        self.out_steps = args.out_steps

        self.learning_rate = args.learning_rate
        self.momentum = args.momentum
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.weight_decay = args.weight_decay
        # Resolve the optimizer class by name, e.g. "Adam" -> optim.Adam.
        self.optim_method = getattr(optim, args.optim_method)

        # Only the generator is optimized by this trainer.
        self.optimizer = {}
        self.optimizer["netG"] = self.optim_method(model["netG"].parameters(),
                                                   lr=self.learning_rate)

        # Preallocated (uninitialized) input buffers, refilled every batch.
        self.composition = torch.FloatTensor(self.batch_size, self.nchannels,
                                             self.resolution_high,
                                             self.resolution_wide)
        self.metadata = torch.FloatTensor(self.batch_size, self.nechannels,
                                          self.resolution_high,
                                          self.resolution_wide)

        if args.out_images is None:
            # BUGFIX: round batch_size UP to the nearest multiple of 8.
            # The previous expression, batch_size + 8 - (batch_size % 8),
            # added a spurious extra 8 when batch_size was already a
            # multiple of 8 (e.g. 16 -> 24 instead of 16).
            self.out_images = self.batch_size + (-self.batch_size % 8)
        else:
            self.out_images = args.out_images

        if args.cuda:
            self.composition = self.composition.cuda()
            self.metadata = self.metadata.cuda()

        # NOTE(review): Variable is a no-op since PyTorch 0.4; kept for
        # consistency with the rest of this file.
        self.composition = Variable(self.composition)
        self.metadata = Variable(self.metadata)

        # File loggers and running-statistics monitors for the L2 loss,
        # one pair each for the train and test phases.
        self.log_loss_train = plugins.Logger(args.logs, 'TrainLogger.txt')
        self.params_loss_train = ['L2']
        self.log_loss_train.register(self.params_loss_train)

        self.log_monitor_train = plugins.Monitor(smoothing=False)
        self.params_monitor_train = ['L2']
        self.log_monitor_train.register(self.params_monitor_train)

        self.log_loss_test = plugins.Logger(args.logs, 'TestLogger.txt')
        self.params_loss_test = ['L2']
        self.log_loss_test.register(self.params_loss_test)

        self.log_monitor_test = plugins.Monitor(smoothing=False)
        self.params_monitor_test = ['L2']
        self.log_monitor_test.register(self.params_monitor_test)

        # visualize training
        self.visualizer_train = plugins.Visualizer(self.port, 'Train',
                                                   args.images, args.pre_name)
        self.params_visualizer_train = {
            'L2': {
                'dtype': 'scalar',
                'vtype': 'plot'
            }
        }
        self.visualizer_train.register(self.params_visualizer_train)

        # visualize testing
        self.visualizer_test = plugins.Visualizer(self.port, 'Test',
                                                  args.images, args.pre_name)
        self.params_visualizer_test = {
            'L2': {
                'dtype': 'scalar',
                'vtype': 'plot'
            }
        }
        self.visualizer_test.register(self.params_visualizer_test)
        self.imgio_test = plugins.ImageIO(args.images, args.pre_name)

        # printf-style progress strings: one "%8.6f" slot per tracked loss
        self.print_train = '[%04d/%04d][%02d/%02d] '
        for item in self.params_loss_train:
            self.print_train = self.print_train + f"{item:4}" + " %8.6f "

        self.print_test = '[%d/%d][%d/%d] '
        for item in self.params_loss_test:
            self.print_test = self.print_test + f"{item:4}" + " %8.6f "

        self.giterations = 1
        self.losses_train = {}
        self.losses_test = {}
        self.mu_sigma = {}
    def __init__(self, args, model, criterion, evaluation):
        """Set up the classification trainer.

        Builds a single optimizer over the parameters of every module in
        ``model``, an optional LR scheduler, preallocated input/label
        buffers, and the logging/monitoring/visualization plugins.

        Args:
            args: parsed command-line namespace (hyperparameters, paths,
                logging/visualization settings).
            model: mapping of names to modules; all are optimized jointly.
            criterion: loss module.
            evaluation: evaluation/metric callable stored for later use.
        """
        self.args = args
        self.model = model
        self.criterion = criterion
        self.evaluation = evaluation
        self.save_results = args.save_results

        self.env = args.env
        self.port = args.port
        self.dir_save = args.save_dir
        self.log_type = args.log_type

        self.cuda = args.cuda
        self.nepochs = args.nepochs
        self.batch_size = args.batch_size

        self.resolution_high = args.resolution_high
        self.resolution_wide = args.resolution_wide

        self.lr = args.learning_rate
        self.optim_method = args.optim_method
        self.optim_options = args.optim_options
        self.scheduler_method = args.scheduler_method
        self.scheduler_options = args.scheduler_options
        self.weight_loss = args.weight_loss

        # setting optimizer for multiple modules: collect every module in a
        # ModuleList so one optimizer updates all of their parameters jointly
        self.module_list = nn.ModuleList()
        for key in list(self.model):
            self.module_list.append(self.model[key])
        # Resolve optimizer/scheduler classes by name, e.g. "Adam", "StepLR".
        self.optimizer = getattr(optim, self.optim_method)(
            self.module_list.parameters(), lr=self.lr, **self.optim_options)
        if self.scheduler_method is not None:
            self.scheduler = getattr(optim.lr_scheduler,
                                     self.scheduler_method)(
                                         self.optimizer,
                                         **self.scheduler_options)

        # for classification: preallocated label/input buffers, refilled
        # every batch
        self.labels = torch.zeros(self.batch_size).long()
        self.inputs = torch.zeros(self.batch_size, self.resolution_high,
                                  self.resolution_wide)

        if args.cuda:
            self.labels = self.labels.cuda()
            self.inputs = self.inputs.cuda()

        # NOTE(review): Variable is a no-op since PyTorch 0.4; another
        # variant of this class in the file uses the modern device= API
        # instead — consider unifying.
        self.inputs = Variable(self.inputs)
        self.labels = Variable(self.labels)

        # logging training: quantities written to the train log file
        self.log_loss = plugins.Logger(args.logs_dir, 'TrainLogger.txt',
                                       self.save_results)
        self.params_loss = ['LearningRate', 'Loss']  #, 'TAR']
        self.log_loss.register(self.params_loss)

        # monitor training: running means accumulated over an epoch
        self.monitor = plugins.Monitor()
        self.params_monitor = {
            'LearningRate': {
                'dtype': 'running_mean'
            },
            'Loss': {
                'dtype': 'running_mean'
            },
            # 'TAR': {'dtype': 'running_mean'},
        }
        self.monitor.register(self.params_monitor)

        # visualize training: scalar plots shared between train/test windows
        self.visualizer = plugins.Visualizer(self.port, self.env, 'Train')
        self.params_visualizer = {
            'LearningRate': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'learning_rate',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            'Loss': {
                'dtype': 'scalar',
                'vtype': 'plot',
                'win': 'loss',
                'layout': {
                    'windows': ['train', 'test'],
                    'id': 0
                }
            },
            # 'TAR': {'dtype': 'scalar', 'vtype': 'plot', 'win': 'mAP',
            #         'layout': {'windows': ['train', 'test'], 'id': 0}},
        }
        self.visualizer.register(self.params_visualizer)

        # display training progress: printf-style string with one "%.4f"
        # slot per tracked quantity
        self.print_formatter = 'Train [%d/%d][%d/%d] '
        for item in self.params_loss:
            self.print_formatter += item + " %.4f "

        self.evalmodules = []
        self.losses = {}