def _net_init(self, init_type='kaiming'):
    """Initialize the weights of ``self.model``.

    Args:
        init_type: name of the initialization scheme; it is printed and then
            forwarded unchanged to ``init_weights`` (default ``'kaiming'``).
    """
    banner = '==> Initializing the network using [{}]'.format(init_type)
    print(banner)
    init_weights(self.model, init_type)
def initialize(self, opt):
    """Build the supervised pose-transfer model from the option namespace ``opt``.

    Creates the transformer ``self.netT`` (ResNet or U-Net, selected by
    ``opt.which_model_T``), an optional discriminator ``self.netD`` (only when
    training with ``opt.loss_weight_gan > 0``), PSNR/SSIM metrics and, in
    training mode, the criteria, optimizers, LR schedulers and fake-image
    pool. At test time the trained ``netT`` weights are loaded.

    Raises:
        NotImplementedError: if ``opt.which_model_T`` is neither 'resnet' nor 'unet'.
    """
    super(SupervisedPoseTransferModel, self).initialize(opt)
    ###################################
    # define transformer
    ###################################
    # transformer input: 3 channels plus the pose representation channels
    t_input_nc = 3 + self.get_pose_dim(opt.pose_type)
    if opt.which_model_T == 'resnet':
        self.netT = networks.ResnetGenerator(
            input_nc=t_input_nc,
            output_nc=3,
            ngf=opt.T_nf,
            norm_layer=networks.get_norm_layer(opt.norm),
            use_dropout=not opt.no_dropout,
            n_blocks=9,
            gpu_ids=opt.gpu_ids)
    elif opt.which_model_T == 'unet':
        self.netT = networks.UnetGenerator_v2(
            input_nc=t_input_nc,
            output_nc=3,
            num_downs=8,
            ngf=opt.T_nf,
            norm_layer=networks.get_norm_layer(opt.norm),
            use_dropout=not opt.no_dropout,
            gpu_ids=opt.gpu_ids)
    else:
        raise NotImplementedError()
    if opt.gpu_ids:
        self.netT.cuda()
    networks.init_weights(self.netT, init_type=opt.init_type)
    ###################################
    # define discriminator
    ###################################
    self.use_GAN = self.is_train and opt.loss_weight_gan > 0
    if self.use_GAN:
        self.netD = networks.define_D_from_params(
            input_nc=t_input_nc if opt.D_cond else 3,
            ndf=opt.D_nf,
            which_model_netD='n_layers',
            n_layers_D=3,
            norm=opt.norm,
            which_gan=opt.which_gan,
            init_type=opt.init_type,
            gpu_ids=opt.gpu_ids)
    else:
        self.netD = None
    ###################################
    # metrics and loss functions
    ###################################
    # Metrics are created unconditionally (as in the sibling models) so they
    # are also available outside training.
    self.crit_psnr = networks.PSNR()
    self.crit_ssim = networks.SSIM()
    if self.is_train:
        self.loss_functions = []
        self.schedulers = []
        self.optimizers = []
        self.crit_L1 = nn.L1Loss()
        self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids)
        self.loss_functions += [self.crit_L1, self.crit_vgg]
        self.optim = torch.optim.Adam(self.netT.parameters(),
                                      lr=opt.lr,
                                      betas=(opt.beta1, opt.beta2))
        self.optimizers += [self.optim]
        if self.use_GAN:
            self.crit_GAN = networks.GANLoss(
                use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            # Register the GAN criterion itself; previously the boolean flag
            # ``self.use_GAN`` was appended here by mistake.
            self.loss_functions.append(self.crit_GAN)
            self.optimizers.append(self.optim_D)
        # TODO: add pose loss
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
        self.fake_pool = ImagePool(opt.pool_size)
    ###################################
    # load trained model
    ###################################
    if not self.is_train:
        # NOTE(review): the sibling models pass ``opt.which_epoch`` as the
        # epoch label here; confirm ``opt.which_model`` is intended.
        self.load_network(self.netT, 'netT', opt.which_model)
shuffle=True) test_data = DataLoader('./data_256/valA', './data_256/valB', transform=transform) test_data_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=False) test_input, test_target = test_data_loader.__iter__().__next__() # Models G = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=6) D = Discriminator(input_nc=6, ndf=64) #G.cuda() #D.cuda() init_weights(G) init_weights(D) #G.init_weights(mean=0.0, std=0.02) #D.init_weights(mean=0.0, std=0.02) # Loss function #BCE_loss = torch.nn.BCELoss()#.cuda() BCE_loss = GANLoss() L1_loss = torch.nn.L1Loss() #.cuda() # Optimizers G_optimizer = torch.optim.Adam(G.parameters(), lr=params.lrG, betas=(params.beta1, params.beta2)) D_optimizer = torch.optim.Adam(D.parameters(), lr=params.lrD,
def initialize(self, opt):
    """Set up the variational U-Net pose-transfer model.

    Builds the generator ``self.netT`` (``networks.VariationalUnet``), an
    optional discriminator ``self.netD`` (training with a positive GAN loss
    weight only), PSNR/SSIM metrics and, in training mode, the VGG loss,
    Adam optimizers, a fake-image pool and LR schedulers. Weights for
    ``opt.which_epoch`` are loaded at test time; optimizer states are also
    restored when ``opt.continue_train`` is set.
    """
    super(VUnetPoseTransferModel, self).initialize(opt)

    # ---- generator -------------------------------------------------------
    generator_kwargs = dict(
        input_nc_dec=self.get_pose_dim(opt.pose_type),
        input_nc_enc=self.get_appearance_dim(opt.appearance_type),
        output_nc=self.get_output_dim(opt.output_type),
        nf=opt.vunet_nf,
        max_nf=opt.vunet_max_nf,
        input_size=opt.fine_size,
        n_latent_scales=opt.vunet_n_latent_scales,
        bottleneck_factor=opt.vunet_bottleneck_factor,
        box_factor=opt.vunet_box_factor,
        n_residual_blocks=2,
        norm_layer=networks.get_norm_layer(opt.norm),
        activation=nn.ReLU(False),
        use_dropout=False,
        gpu_ids=opt.gpu_ids,
        output_tanh=False,
    )
    self.netT = networks.VariationalUnet(**generator_kwargs)
    if opt.gpu_ids:
        self.netT.cuda()
    networks.init_weights(self.netT, init_type=opt.init_type)

    # ---- discriminator ---------------------------------------------------
    self.use_GAN = self.is_train and opt.loss_weight_gan > 0
    if not self.use_GAN:
        self.netD = None
    else:
        # conditional D sees 3 channels plus the pose channels
        d_input_nc = 3 + self.get_pose_dim(opt.pose_type) if opt.D_cond else 3
        self.netD = networks.define_D_from_params(
            input_nc=d_input_nc,
            ndf=opt.D_nf,
            which_model_netD='n_layers',
            n_layers_D=opt.D_n_layer,
            norm=opt.norm,
            which_gan=opt.which_gan,
            init_type=opt.init_type,
            gpu_ids=opt.gpu_ids)

    # ---- metrics, losses and optimizers ----------------------------------
    self.crit_psnr = networks.PSNR()
    self.crit_ssim = networks.SSIM()
    if self.is_train:
        adam_betas = (opt.beta1, opt.beta2)
        self.optimizers = []
        self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids,
                                            opt.content_layer_weight,
                                            opt.style_layer_weight,
                                            opt.shifted_style)
        self.optim = torch.optim.Adam(self.netT.parameters(),
                                      lr=opt.lr,
                                      betas=adam_betas,
                                      weight_decay=opt.weight_decay)
        self.optimizers.append(self.optim)
        if self.use_GAN:
            self.crit_GAN = networks.GANLoss(use_lsgan=opt.which_gan == 'lsgan',
                                             tensor=self.Tensor)
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=adam_betas)
            self.optimizers.append(self.optim_D)
        # TODO: add pose loss
        self.fake_pool = ImagePool(opt.pool_size)

    # ---- restore weights -------------------------------------------------
    epoch = opt.which_epoch
    if not self.is_train:
        self.load_network(self.netT, 'netT', epoch)
    elif opt.continue_train:
        self.load_network(self.netT, 'netT', epoch)
        self.load_optim(self.optim, 'optim', epoch)
        if self.use_GAN:
            self.load_network(self.netD, 'netD', epoch)
            self.load_optim(self.optim_D, 'optim_D', epoch)

    # ---- LR schedulers ---------------------------------------------------
    if self.is_train:
        self.schedulers = [networks.get_scheduler(o, opt) for o in self.optimizers]
def initialize(self, opt):
    """Build the encoder/decoder pair for one guidance type.

    Exactly one of ``opt.use_shape``, ``opt.use_edge`` or ``opt.use_color``
    (checked in that order) selects the encoder; the decoder is configured
    from the matching ``opt.<type>_ndowns`` / ``_nf`` / ``_nof`` options.
    When ``opt.decode_guided`` is set and ``self.feat_size`` is 1, a
    ``FeatureConcatNetwork`` (guide_nc=``opt.shape_nc``) is created as well.
    Weights are loaded at test time or when continuing training; L1 and
    cross-entropy criteria, an Adam optimizer and its scheduler are always
    created.

    Raises:
        ValueError: if none of use_shape / use_edge / use_color is set.
    """
    super(EncoderDecoderFramework, self).initialize(opt)
    ###################################
    # load/define networks
    ###################################
    if opt.use_shape:
        self.encoder_type = 'shape'
        self.encoder_name = 'shape_encoder'
        self.decoder_name = 'decoder'
    elif opt.use_edge:
        self.encoder_type = 'edge'
        self.encoder_name = 'edge_encoder'
        self.decoder_name = 'decoder'
    elif opt.use_color:
        self.encoder_type = 'color'
        self.encoder_name = 'color_encoder'
        self.decoder_name = 'decoder'
    else:
        raise ValueError(
            'either use_shape, use_edge, use_color should be set')
    # encoder
    self.encoder = networks.define_image_encoder(opt, self.encoder_type)

    # decoder hyper-parameters for the selected encoder type
    if self.encoder_type == 'shape':
        ndowns, nf, nof = opt.shape_ndowns, opt.shape_nf, opt.shape_nof
        output_nc = 7
        output_activation = None
        assert not opt.decode_guided, 'decode_guided is not supported with the shape encoder'
    elif self.encoder_type == 'edge':
        ndowns, nf, nof = opt.edge_ndowns, opt.edge_nf, opt.edge_nof
        output_nc = 1
        output_activation = None
    else:  # 'color'
        ndowns, nf, nof = opt.color_ndowns, opt.color_nf, opt.color_nof
        output_nc = 3
        output_activation = nn.Tanh

    if opt.encoder_type in {'normal', 'st'}:
        # Use the selected encoder's own depth; the previous code always read
        # opt.edge_ndowns, which is wrong for the shape/color encoders.
        # NOTE(review): 256 presumably is the input resolution — confirm.
        self.feat_size = 256 // 2**ndowns
        self.mid_feat_size = self.feat_size
    else:
        self.feat_size = 1
        self.mid_feat_size = 8

    def make_decoder(input_nc, num_ups):
        """Create the image decoder with the type-specific settings above."""
        return networks.define_image_decoder_from_params(
            input_nc=input_nc,
            output_nc=output_nc,
            nf=nf,
            num_ups=num_ups,
            norm=opt.norm,
            output_activation=output_activation,
            gpu_ids=opt.gpu_ids,
            init_type=opt.init_type)

    self.use_concat_net = False
    if opt.decode_guided:
        if self.feat_size > 1:
            # decoder input: nof feature channels plus opt.shape_nc guide channels
            self.decoder = make_decoder(nof + opt.shape_nc, ndowns)
        else:
            self.decoder = make_decoder(nof, 5)
            self.concat_net = networks.FeatureConcatNetwork(
                feat_nc=nof,
                guide_nc=opt.shape_nc,
                nblocks=3,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids)
            if len(self.gpu_ids) > 0:
                self.concat_net.cuda()
            networks.init_weights(self.concat_net, opt.init_type)
            self.use_concat_net = True
            print('encoder_decoder contains a feature_concat_network!')
    else:
        self.decoder = make_decoder(nof, ndowns if self.feat_size > 1 else 8)

    if not self.is_train or self.opt.continue_train:
        # The option used to be read as ``which_opoch`` (a typo). Prefer
        # ``which_epoch`` and fall back for option sets that only define the
        # misspelled name.
        which_epoch = opt.which_epoch if hasattr(opt, 'which_epoch') else opt.which_opoch
        self.load_network(self.encoder, self.encoder_name, which_epoch)
        self.load_network(self.decoder, self.decoder_name, which_epoch)
        if self.use_concat_net:
            self.load_network(self.concat_net, 'concat_net', which_epoch)

    # loss functions
    self.loss_functions = []
    self.schedulers = []
    self.crit_L1 = networks.SmoothLoss(nn.L1Loss())
    self.crit_CE = networks.SmoothLoss(nn.CrossEntropyLoss())
    self.loss_functions += [self.crit_L1, self.crit_CE]

    param_groups = [{'params': self.encoder.parameters()},
                    {'params': self.decoder.parameters()}]
    if self.use_concat_net:
        # concat_net is initialized, saved and loaded with the model, so it
        # must be optimized too (it was previously left out of the optimizer).
        param_groups.append({'params': self.concat_net.parameters()})
    self.optim = torch.optim.Adam(param_groups,
                                  lr=opt.lr,
                                  betas=(opt.beta1, opt.beta2))
    self.optimizers = [self.optim]
    for optim in self.optimizers:
        self.schedulers.append(networks.get_scheduler(optim, opt))
def initialize(self, opt):
    """Build the flow-regression network, its losses and optimizer.

    ``opt.which_model`` selects ``networks.FlowUnet`` ('unet') or
    ``networks.FlowUnet_v2`` ('unet_v2'). Flow and visibility criteria are
    always created; the self-supervised flow loss only with
    ``opt.use_ss_flow_loss``. Weights are loaded for ``opt.which_epoch`` at
    test time, or network and optimizer state for ``opt.last_epoch`` when
    ``opt.resume_train`` is set.

    Raises:
        NotImplementedError: if ``opt.which_model`` is not 'unet' or 'unet_v2'
            (previously ``self.netF`` was silently left undefined).
    """
    super(FlowRegressionModel, self).initialize(opt)
    ###################################
    # define flow networks
    ###################################
    input_nc = self.get_input_dim(opt.input_type1) + self.get_input_dim(opt.input_type2)
    if opt.which_model == 'unet':
        self.netF = networks.FlowUnet(
            input_nc=input_nc,
            nf=opt.nf,
            start_scale=opt.start_scale,
            num_scale=opt.num_scale,
            norm=opt.norm,
            gpu_ids=opt.gpu_ids,
        )
    elif opt.which_model == 'unet_v2':
        self.netF = networks.FlowUnet_v2(
            input_nc=input_nc,
            nf=opt.nf,
            max_nf=opt.max_nf,
            start_scale=opt.start_scale,
            num_scales=opt.num_scale,
            norm=opt.norm,
            gpu_ids=opt.gpu_ids,
        )
    else:
        raise NotImplementedError(
            'unsupported flow network: {}'.format(opt.which_model))
    if opt.gpu_ids:
        self.netF.cuda()
    networks.init_weights(self.netF, init_type=opt.init_type)
    ###################################
    # loss and optimizers
    ###################################
    self.crit_flow = networks.MultiScaleFlowLoss(
        start_scale=opt.start_scale,
        num_scale=opt.num_scale,
        loss_type=opt.flow_loss_type)
    # visibility classes: 0 = visible, 1 = invisible, 2 = background
    self.crit_vis = nn.CrossEntropyLoss()
    if opt.use_ss_flow_loss:
        self.crit_flow_ss = networks.SS_FlowLoss(loss_type='l1')
    if self.is_train:
        self.optimizers = []
        self.optim = torch.optim.Adam(self.netF.parameters(),
                                      lr=opt.lr,
                                      betas=(opt.beta1, opt.beta2),
                                      weight_decay=opt.weight_decay)
        self.optimizers.append(self.optim)
    ###################################
    # load trained model
    ###################################
    if not self.is_train:
        # load trained model for test
        print('load pretrained model')
        self.load_network(self.netF, 'netF', opt.which_epoch)
    elif opt.resume_train:
        # resume training
        print('resume training')
        self.load_network(self.netF, 'netF', opt.last_epoch)
        self.load_optim(self.optim, 'optim', opt.last_epoch)
    ###################################
    # schedulers
    ###################################
    if self.is_train:
        self.schedulers = []
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
def main(config, needs_save, i):
    """Train and validate a ResUNet segmentation model on one CV split.

    Args:
        config: nested configuration object; the ``run``, ``train_dataset``,
            ``val_dataset``, ``model``, ``optimizer``, ``focal_loss``,
            ``active_contour_loss``, ``metric``, ``label_to_id`` and ``save``
            sections are read here.
        needs_save: when truthy, the config, per-epoch logs, checkpoints and
            preview images are written to the output directory.
        i: cross-validation split index, forwarded to ``get_cv_splits`` and
            ``get_output_dir_path``.
    """
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    # train and validation patients are split from the same directory
    assert config.train_dataset.root_dir_path == config.val_dataset.root_dir_path
    train_patient_ids, val_patient_ids = get_cv_splits(
        config.train_dataset.root_dir_path, i)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {v: k for k, v in class_name_to_index.items()}

    train_data_loader = get_data_loader(
        mode='train',
        dataset_name=config.train_dataset.dataset_name,
        root_dir_path=config.train_dataset.root_dir_path,
        patient_ids=train_patient_ids,
        batch_size=config.train_dataset.batch_size,
        num_workers=config.train_dataset.num_workers,
        volume_size=config.train_dataset.volume_size,
    )
    val_data_loader = get_data_loader(
        mode='val',
        dataset_name=config.val_dataset.dataset_name,
        root_dir_path=config.val_dataset.root_dir_path,
        patient_ids=val_patient_ids,
        batch_size=config.val_dataset.batch_size,
        num_workers=config.val_dataset.num_workers,
        volume_size=config.val_dataset.volume_size,
    )

    model = ResUNet(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        filters=config.model.filters,
    )
    print(model)
    if config.run.use_cuda:
        model.cuda()
        model = nn.DataParallel(model)

    if config.model.saved_model:
        print('Loading saved model: {}'.format(config.model.saved_model))
        model.load_state_dict(torch.load(config.model.saved_model))
    else:
        print('Initializing weights.')
        init_weights(model, init_type=config.model.init_type)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=config.optimizer.lr,
                           betas=config.optimizer.betas,
                           weight_decay=config.optimizer.weight_decay)

    dice_loss = SoftDiceLoss()
    focal_loss = FocalLoss(
        gamma=config.focal_loss.gamma,
        alpha=config.focal_loss.alpha,
    )
    active_contour_loss = ActiveContourLoss(
        weight=config.active_contour_loss.weight,
    )
    # NOTE(review): this single DiceCoefficient instance is updated by both
    # the training and the evaluation step; if ``update`` accumulates state
    # the two phases share it — confirm this is intended.
    dice_coeff = DiceCoefficient(
        n_classes=config.metric.n_classes,
        index_to_class_name=index_to_class_name,
    )
    one_hot_encoder = OneHotEncoder(
        n_classes=config.metric.n_classes,
    ).forward

    def prepare_batch(batch):
        """Move a batch to the GPU when enabled; cast image->float, label->long."""
        image = batch['image']
        label = batch['label']
        if config.run.use_cuda:
            image = image.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)
        return image.float(), label.long()

    def compute_losses(output, target):
        """Return the (dice, focal, active-contour) losses for one batch."""
        return (dice_loss(output, target),
                focal_loss(output, target),
                active_contour_loss(output, target))

    def summarize(losses, output, label):
        """Update the dice metric and return the per-batch measures dict."""
        l_dice, l_focal, l_active_contour = losses
        m_dice = dice_coeff.update(output.detach(), label)
        measures = {
            'SoftDiceLoss': l_dice.item(),
            'FocalLoss': l_focal.item(),
            'ActiveContourLoss': l_active_contour.item(),
        }
        measures.update(m_dice)
        if config.run.use_cuda:
            torch.cuda.synchronize()
        return measures

    def train(engine, batch):
        adjust_learning_rate(optimizer, engine.state.epoch,
                             initial_lr=config.optimizer.lr,
                             n_epochs=config.run.n_epochs,
                             gamma=config.optimizer.gamma)
        model.train()
        image, label = prepare_batch(batch)
        optimizer.zero_grad()
        output = model(image)
        # one-hot label without channel 0 (presumably background)
        target = one_hot_encoder(label)[:, 1:, ...]
        losses = compute_losses(output, target)
        l_total = losses[0] + losses[1] + losses[2]
        l_total.backward()
        optimizer.step()
        return summarize(losses, output, label)

    def evaluate(engine, batch):
        model.eval()
        image, label = prepare_batch(batch)
        with torch.no_grad():
            output = model(image)
            target = one_hot_encoder(label)[:, 1:, ...]
            losses = compute_losses(output, target)
        return summarize(losses, output, label)

    output_dir_path = get_output_dir_path(config, i)
    trainer = Engine(train)
    evaluator = Engine(evaluate)
    timer = Timer(average=True)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir_path,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.run.n_epochs + 1,
            create_dir=True,
        )

    monitoring_metrics = ['SoftDiceLoss', 'FocalLoss', 'ActiveContourLoss']
    monitoring_metrics += class_name_to_index.keys()
    for eng in (trainer, evaluator):
        for metric in monitoring_metrics:
            # partial binds ``metric`` now, avoiding late-binding closures
            RunningAverage(alpha=0.98,
                           output_transform=partial(lambda x, metric: x[metric],
                                                    metric=metric)).attach(eng, metric)

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    def save_images(evaluator, epoch):
        """Save a mid-depth slice of the last validation batch with label and prediction."""
        image, label = prepare_batch(evaluator.state.batch)
        with torch.no_grad():
            pred = model(image)

        # label map: 0 where every predicted channel is < 0.5, else argmax + 1
        output = torch.ones_like(label)
        mask = (pred[:, 0, ...] < 0.5) * (pred[:, 1, ...] < 0.5) * (pred[:, 2, ...] < 0.5)
        output += pred.argmax(1)
        output[mask] = 0

        image = image.detach().cpu().float()
        label = label.detach().cpu().unsqueeze(1).float()
        output = output.detach().cpu().unsqueeze(1).float()

        z_middle = image.shape[-1] // 2
        image = image[:, 0, ..., z_middle]
        label = label[:, 0, ..., z_middle]
        output = output[:, 0, ..., z_middle]

        if config.save.image_vmax is not None:
            vmax = config.save.image_vmax
        else:
            vmax = image.max()
        if config.save.image_vmin is not None:
            vmin = config.save.image_vmin
        else:
            vmin = image.min()

        image = np.clip(image, vmin, vmax)
        image -= vmin
        value_range = vmax - vmin
        # A constant slice gives a zero range; skip the division rather than
        # writing a NaN image (the slice is already all zeros here).
        if value_range != 0:
            image /= value_range
        image *= 255.0

        save_path = os.path.join(output_dir_path, 'result_{}.png'.format(epoch))
        save_images_via_plt(image, label, output, config.save.n_save_images,
                            config, save_path)

    @trainer.on(Events.STARTED)
    def call_save_config(engine):
        if needs_save:
            return save_config(engine, config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_save_logs(engine):
        if needs_save:
            return save_logs('train', engine, config, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_print_times(engine):
        return print_times(engine, config, pbar, timer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_data_loader, 1)
        if needs_save:
            save_logs('val', evaluator, config, output_dir_path)
            save_images(evaluator, trainer.state.epoch)

    if needs_save:
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'model': model,
                                      'optim': optimizer
                                  })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader)))

    trainer.run(train_data_loader, config.run.n_epochs)
def initialize(self, opt):
    """Build the pose-transfer generator, optional pixel warper/discriminator/flow net.

    ``opt.which_model_G`` selects ``UnetGenerator`` ('unet') or
    ``DualUnetGenerator`` ('dual_unet'). With ``opt.G_pix_warp`` an external
    pixel-warper ``self.netPW`` is built and becomes the only network whose
    parameters are given to ``self.optim``. A discriminator is built when
    training with a positive GAN weight, and a pretrained flow network is
    loaded when ``opt.flow_on_the_fly`` is set. Weights are loaded for
    testing, from a pretrained generator, or when resuming training.

    Raises:
        NotImplementedError: if ``opt.which_model_G`` is not 'unet' or
            'dual_unet' (previously ``self.netG`` was left undefined).
    """
    super(PoseTransferModel, self).initialize(opt)
    ###################################
    # define generator
    ###################################
    g_activation = nn.LeakyReLU(0.1) if opt.G_activation == 'leaky_relu' else nn.ReLU()
    if opt.which_model_G == 'unet':
        self.netG = networks.UnetGenerator(
            input_nc=self.get_tensor_dim('+'.join(
                [opt.G_appearance_type, opt.G_pose_type])),
            output_nc=3,
            nf=opt.G_nf,
            max_nf=opt.G_max_nf,
            num_scales=opt.G_n_scale,
            n_residual_blocks=2,
            norm=opt.G_norm,
            activation=g_activation,
            use_dropout=opt.use_dropout,
            gpu_ids=opt.gpu_ids)
    elif opt.which_model_G == 'dual_unet':
        self.netG = networks.DualUnetGenerator(
            pose_nc=self.get_tensor_dim(opt.G_pose_type),
            appearance_nc=self.get_tensor_dim(opt.G_appearance_type),
            output_nc=3,
            aux_output_nc=[],
            nf=opt.G_nf,
            max_nf=opt.G_max_nf,
            num_scales=opt.G_n_scale,
            num_warp_scales=opt.G_n_warp_scale,
            n_residual_blocks=2,
            norm=opt.G_norm,
            vis_mode=opt.G_vis_mode,
            activation=g_activation,
            use_dropout=opt.use_dropout,
            no_end_norm=opt.G_no_end_norm,
            gpu_ids=opt.gpu_ids,
        )
    else:
        raise NotImplementedError(
            'unsupported generator: {}'.format(opt.which_model_G))
    if opt.gpu_ids:
        self.netG.cuda()
    networks.init_weights(self.netG, init_type=opt.init_type)
    ###################################
    # define external pixel warper
    ###################################
    if opt.G_pix_warp:
        pix_warp_n_scale = opt.G_n_scale
        self.netPW = networks.UnetGenerator_MultiOutput(
            input_nc=self.get_tensor_dim(opt.G_pix_warp_input_type),
            output_nc=[1],  # only use one output branch (weight mask)
            nf=32,
            max_nf=128,
            num_scales=pix_warp_n_scale,
            n_residual_blocks=2,
            norm=opt.G_norm,
            activation=nn.ReLU(False),
            use_dropout=False,
            gpu_ids=opt.gpu_ids)
        if opt.gpu_ids:
            self.netPW.cuda()
        networks.init_weights(self.netPW, init_type=opt.init_type)
    ###################################
    # define discriminator
    ###################################
    self.use_gan = self.is_train and self.opt.loss_weight_gan > 0
    if self.use_gan:
        self.netD = networks.NLayerDiscriminator(
            input_nc=self.get_tensor_dim(opt.D_input_type_real),
            ndf=opt.D_nf,
            n_layers=opt.D_n_layers,
            use_sigmoid=(opt.gan_type == 'dcgan'),
            output_bias=True,
            gpu_ids=opt.gpu_ids,
        )
        if opt.gpu_ids:
            self.netD.cuda()
        networks.init_weights(self.netD, init_type=opt.init_type)
    ###################################
    # load optical flow model
    ###################################
    if opt.flow_on_the_fly:
        self.netF = load_flow_network(opt.pretrained_flow_id,
                                      opt.pretrained_flow_epoch, opt.gpu_ids)
        self.netF.eval()
        if opt.gpu_ids:
            self.netF.cuda()
    ###################################
    # loss and optimizers
    ###################################
    self.crit_psnr = networks.PSNR()
    self.crit_ssim = networks.SSIM()

    if self.is_train:
        self.crit_vgg = networks.VGGLoss(
            opt.gpu_ids,
            shifted_style=opt.shifted_style_loss,
            content_weights=opt.vgg_content_weights)
        # with the external pixel warper, only netPW's parameters are optimized
        trainable_net = self.netPW if opt.G_pix_warp else self.netG
        self.optim = torch.optim.Adam(trainable_net.parameters(),
                                      lr=opt.lr,
                                      betas=(opt.beta1, opt.beta2),
                                      weight_decay=opt.weight_decay)
        self.optimizers = [self.optim]
        if self.use_gan:
            self.crit_gan = networks.GANLoss(
                use_lsgan=(opt.gan_type == 'lsgan'))
            if self.gpu_ids:
                self.crit_gan.cuda()
            self.optim_D = torch.optim.Adam(
                self.netD.parameters(),
                lr=opt.lr_D,
                betas=(opt.beta1, opt.beta2),
                weight_decay=opt.weight_decay_D)
            self.optimizers += [self.optim_D]
    ###################################
    # load trained model
    ###################################
    if not self.is_train:
        # load trained model for testing
        self.load_network(self.netG, 'netG', opt.which_epoch)
        if opt.G_pix_warp:
            self.load_network(self.netPW, 'netPW', opt.which_epoch)
    elif opt.pretrained_G_id is not None:
        # load pretrained network
        self.load_network(self.netG, 'netG', opt.pretrained_G_epoch,
                          opt.pretrained_G_id)
    elif opt.resume_train:
        # resume training
        self.load_network(self.netG, 'netG', opt.which_epoch)
        self.load_optim(self.optim, 'optim', opt.which_epoch)
        if self.use_gan:
            self.load_network(self.netD, 'netD', opt.which_epoch)
            self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
        if opt.G_pix_warp:
            self.load_network(self.netPW, 'netPW', opt.which_epoch)
    ###################################
    # schedulers
    ###################################
    if self.is_train:
        self.schedulers = []
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
def __init__(self, opt):
    """Build the DASGIL model from the option namespace ``opt``.

    Always creates the generator, the ``self.Tensor`` factory and the
    preallocated input tensors. In training mode (``opt.isTrain``) it also
    initializes the generator weights and creates the generator optimizer,
    the discriminator ``self.dis_f`` with its optimizer, the cross-entropy
    criterion and step LR schedulers. Generator weights for
    ``opt.which_epoch`` are loaded at test time or when ``opt.continue_train``
    is set; the discriminator is also restored when continuing training.

    Raises:
        ValueError: in training mode, if ``opt.dis_type`` is neither "FD" nor "CD".
    """
    super(DASGIL, self).__init__()
    self.opt = opt
    self.isTrain = opt.isTrain
    if self.opt.gpu_ids >= 0:
        self.Tensor = torch.cuda.FloatTensor
    else:
        self.Tensor = torch.FloatTensor
    self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    # generator
    self.generator = Generator(opt)
    if opt.isTrain:
        self.generator.weight_init(0, 0.02)
        self.gen_parameters = list(self.generator.parameters())
        self.gen_optimizer = torch.optim.Adam(self.gen_parameters,
                                              lr=opt.lr,
                                              betas=(0.9, 0.999))

        # discriminator: an explicit check replaces the former ``assert``
        # (stripped under ``python -O``) whose ``else`` branch was unreachable
        # and would have left ``self.dis_f`` undefined.
        if opt.dis_type == "FD":
            self.dis_f = FlattenDiscriminator(opt.dis_nc, opt.dis_nlayers)
        elif opt.dis_type == "CD":
            self.dis_f = CascadeDiscriminator()
        else:
            raise ValueError(
                "ERROR: only FD or CD is supported, got {!r}".format(opt.dis_type))
        init_weights(self.dis_f, 'normal', opt=self.opt)
        dis_params = list(self.dis_f.parameters())
        self.dis_optimizer = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=opt.lr_dis,
            betas=(0.5, 0.9))

        weight = torch.ones(self.opt.num_classes)
        self.criterion_corssentropy = CrossEntropyLoss2d(weight)

    # initialize inputs
    def _input_buffer(n_channels):
        return self.Tensor(opt.batch_size, n_channels, opt.resized_h, opt.resized_w)

    self.input_GAN_real = _input_buffer(3)
    self.input_rgb_A = _input_buffer(3)
    self.input_depth_A = _input_buffer(1)
    self.input_seg_A = _input_buffer(3)
    self.input_rgb_A_prime = _input_buffer(3)
    self.input_depth_A_prime = _input_buffer(1)
    self.input_seg_A_prime = _input_buffer(3)
    self.input_rgb_B = _input_buffer(3)
    self.input_depth_B = _input_buffer(1)
    self.input_seg_B = _input_buffer(3)

    # load checkpoints
    if not opt.isTrain or opt.continue_train:
        save_filename_generator = '%d_net_%s' % (opt.which_epoch, 'gen')
        filename_generator = os.path.join(self.save_dir, save_filename_generator) + '.pth'
        self.generator.load_state_dict(torch.load(filename_generator))
        # the discriminator only exists in training mode
        if opt.isTrain and opt.continue_train:
            save_filename_dis = '%d_net_%s' % (opt.which_epoch, 'dis')
            filename_dis = os.path.join(self.save_dir, save_filename_dis) + '.pth'
            self.dis_f.load_state_dict(torch.load(filename_dis))

    if opt.isTrain:
        # lr scheduler update
        self.gen_scheduler = lr_scheduler.StepLR(self.gen_optimizer,
                                                 step_size=opt.step_lr_epoch,
                                                 gamma=opt.gamma_lr)
        self.gen_scheduler.last_epoch = opt.which_epoch
        self.dis_scheduler = lr_scheduler.StepLR(self.dis_optimizer,
                                                 step_size=opt.step_lr_epoch,
                                                 gamma=opt.gamma_lr)
        self.dis_scheduler.last_epoch = opt.which_epoch
def main():
    """Fine-tune a pretrained ENet over up to 10 rounds of negative-sample mining.

    Each round ``itr`` in 1..10 calls ``construct_negative_samples(itr)``,
    rebuilds the training loader, trains for 15 epochs via ``train_multi``
    and stops early when ``adv_model_selection`` returns a model count
    below 1. Finally ``final_model_selection`` runs on the collected
    natural-error records.
    """
    # instantiate model and initialize weights
    model = ENet()
    networks.print_network(model)
    networks.init_weights(model, init_type='normal')
    model.init_convFilter(trainable=srm_trainable)
    if args.cuda:
        model.cuda()

    print('using pretrained model')
    # Without CUDA, remap GPU-saved tensors onto the CPU so loading does not fail.
    checkpoint = torch.load(project_root + args.log_dir + '/checkpoint_300.pth',
                            map_location=None if args.cuda else 'cpu')
    model.load_state_dict(checkpoint['state_dict'])
    args.lr = args.lr * 0.001
    threshold = THRESHOLD_MAX

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    L1_criterion = nn.L1Loss(reduction='sum')
    if args.cuda:
        criterion = criterion.cuda()
        L1_criterion = L1_criterion.cuda()

    if not srm_trainable:
        # leave the convFilter1 parameters out of the optimizer
        params = [param for name, param in model.named_parameters()
                  if 'convFilter1' not in name]
        optimizer = create_optimizer(params, args.lr)
    else:
        optimizer = create_optimizer(model.parameters(), args.lr)

    nature_error_itr_global = []
    for itr in np.arange(1, 11):
        args.dataroot = dst_dir
        nature_error_itr_local = []

        # adding negative samples into the original training dataset
        construct_negative_samples(itr)

        train_loader = myDataset.DataLoaderHalf(
            myDataset.MyDataset(
                args,
                transforms.Compose([
                    transforms.RandomCrop(233),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            half_constraint=True,
            sampler_type='RandomBalancedSampler',
            **kwargs)
        print('The number of train data:{}'.format(len(train_loader.dataset)))
        args.epochs = 15

        # start from itr = 1
        train_multi(train_loader, optimizer, model, criterion, L1_criterion,
                    val_loader, itr, nature_error_itr_local,
                    nature_error_itr_global)

        if nature_error_itr_local:
            adv_model_num, adv_model_idx = adv_model_selection(
                nature_error_itr_local, threshold, itr)
            if adv_model_num < 1:
                break

    print(nature_error_itr_global)
    print(len(nature_error_itr_global) / (args.epochs - args.epochs // 2))
    final_model_selection(nature_error_itr_global, threshold)
def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
    """Construct a PatchGAN discriminator.

    Parameters:
        input_nc (int)  -- the number of channels in input images
        ndf (int)       -- the number of filters in the first conv layer
        n_layers (int)  -- the number of stride-2 conv layers in the discriminator
        norm_layer      -- normalization layer (a class or a functools.partial of one)

    The network ends with a 1-channel conv, i.e. it outputs a 1-channel
    prediction map; weights are initialized with ``init_weights(..., 'xavier')``.
    """
    super(NLayerDiscriminator, self).__init__()
    # BatchNorm2d has its own affine shift, so the convs feeding it need no bias.
    base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
    use_bias = base_norm != nn.BatchNorm2d

    kw = 4
    padw = 1
    sequence = [
        nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
        nn.LeakyReLU(0.2, True)
    ]
    nf_mult = 1
    for n in range(1, n_layers):  # gradually increase the number of filters
        nf_mult_prev = nf_mult
        nf_mult = min(2**n, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                      stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

    nf_mult_prev = nf_mult
    nf_mult = min(2**n_layers, 8)
    sequence += [
        nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                  stride=1, padding=padw, bias=use_bias),
        norm_layer(ndf * nf_mult),
        nn.LeakyReLU(0.2, True)
    ]

    # output 1 channel prediction map
    sequence += [
        nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
    ]
    self.model = nn.Sequential(*sequence)
    init_weights(self.model, 'xavier')
def initialize(self, opt):
    """Set up the two-stage (coarse -> refine) pose-transfer model.

    Stage 1 (``self.netT_s1``) is created by ``_create_stage_1_net``. Stage 2
    consists of an encoder ``self.netT_s2e`` selected by
    ``opt.which_model_s2e`` and a decoder ``self.netT_s2d`` selected by
    ``self.opt.which_model_s2d``. An optional discriminator, PSNR/SSIM
    metrics and, in training mode, the VGG loss, Adam optimizers (stage 1
    only if ``opt.train_s1``), a fake-image pool and LR schedulers are
    created. Weights are either freshly initialized (stage 1 loaded from its
    own run) or restored for ``opt.which_epoch``.

    Raises:
        NotImplementedError: for an unknown stage-2 encoder or decoder type.
    """
    super(TwoStagePoseTransferModel, self).initialize(opt)

    # ---- stage 1: pretrained coarse network ------------------------------
    self._create_stage_1_net(opt)

    # ---- stage 2: encoder ------------------------------------------------
    encoder_kind = opt.which_model_s2e
    if encoder_kind == 'patch_embed':
        self.netT_s2e = networks.LocalPatchEncoder(
            n_patch=len(opt.patch_indices),
            input_nc=3,
            output_nc=opt.s2e_nof,
            nf=opt.s2e_nf,
            max_nf=opt.s2e_max_nf,
            input_size=opt.patch_size,
            bottleneck_factor=opt.s2e_bottleneck_factor,
            n_residual_blocks=2,
            norm_layer=networks.get_norm_layer(opt.norm),
            activation=nn.ReLU(False),
            use_dropout=False,
            gpu_ids=opt.gpu_ids,
        )
        s2e_nof = opt.s2e_nof
    elif encoder_kind == 'patch':
        self.netT_s2e = networks.LocalPatchRearranger(
            n_patch=len(opt.patch_indices),
            image_size=opt.fine_size,
        )
        s2e_nof = 3
    elif encoder_kind == 'seg_embed':
        self.netT_s2e = networks.SegmentRegionEncoder(
            seg_nc=self.opt.seg_nc,
            input_nc=3,
            output_nc=opt.s2e_nof,
            # NOTE(review): this uses the decoder width opt.s2d_nf, while the
            # 'patch_embed' branch uses opt.s2e_nf — confirm it is intended
            # (changing it would invalidate existing checkpoints).
            nf=opt.s2d_nf,
            input_size=opt.fine_size,
            n_blocks=3,
            norm_layer=networks.get_norm_layer(opt.norm),
            activation=nn.ReLU,
            use_dropout=False,
            grid_level=opt.s2e_grid_level,
            gpu_ids=opt.gpu_ids,
        )
        s2e_nof = opt.s2e_nof + opt.s2e_grid_level
    else:
        raise NotImplementedError()
    if opt.gpu_ids:
        self.netT_s2e.cuda()

    # ---- stage 2: decoder (3 channels plus s2e_nof encoder channels) -----
    s2d_input_nc = 3 + s2e_nof
    decoder_kind = self.opt.which_model_s2d
    if decoder_kind == 'resnet':
        self.netT_s2d = networks.ResnetGenerator(
            input_nc=s2d_input_nc,
            output_nc=3,
            ngf=opt.s2d_nf,
            norm_layer=networks.get_norm_layer(opt.norm),
            activation=nn.ReLU,
            use_dropout=False,
            n_blocks=opt.s2d_nblocks,
            gpu_ids=opt.gpu_ids,
            output_tanh=False,
        )
    elif decoder_kind == 'unet':
        self.netT_s2d = networks.UnetGenerator_v2(
            input_nc=s2d_input_nc,
            output_nc=3,
            num_downs=8,
            ngf=opt.s2d_nf,
            max_nf=opt.s2d_nf * 2**3,
            norm_layer=networks.get_norm_layer(opt.norm),
            use_dropout=False,
            gpu_ids=opt.gpu_ids,
            output_tanh=False,
        )
    elif decoder_kind == 'rpresnet':
        self.netT_s2d = networks.RegionPropagationResnetGenerator(
            input_nc=s2d_input_nc,
            output_nc=3,
            ngf=opt.s2d_nf,
            norm_layer=networks.get_norm_layer(opt.norm),
            activation=nn.ReLU,
            use_dropout=False,
            nblocks=opt.s2d_nblocks,
            gpu_ids=opt.gpu_ids,
            output_tanh=False)
    else:
        raise NotImplementedError()
    if opt.gpu_ids:
        self.netT_s2d.cuda()

    # ---- discriminator ---------------------------------------------------
    self.use_GAN = self.is_train and opt.loss_weight_gan > 0
    if not self.use_GAN:
        self.netD = None
    else:
        d_input_nc = 3 + self.get_pose_dim(opt.pose_type) if opt.D_cond else 3
        self.netD = networks.define_D_from_params(
            input_nc=d_input_nc,
            ndf=opt.D_nf,
            which_model_netD='n_layers',
            n_layers_D=opt.D_n_layer,
            norm=opt.norm,
            which_gan=opt.which_gan,
            init_type=opt.init_type,
            gpu_ids=opt.gpu_ids)

    # ---- metrics, losses and optimizers ----------------------------------
    self.crit_psnr = networks.PSNR()
    self.crit_ssim = networks.SSIM()
    if self.is_train:
        adam_betas = (opt.beta1, opt.beta2)
        self.optimizers = []
        self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids,
                                            opt.content_layer_weight,
                                            opt.style_layer_weight,
                                            opt.shifted_style)
        stage2_groups = [{'params': self.netT_s2e.parameters()},
                         {'params': self.netT_s2d.parameters()}]
        self.optim = torch.optim.Adam(stage2_groups, lr=opt.lr, betas=adam_betas)
        self.optimizers.append(self.optim)
        if opt.train_s1:
            self.optim_s1 = torch.optim.Adam(self.netT_s1.parameters(),
                                             lr=opt.lr_s1,
                                             betas=adam_betas)
            self.optimizers.append(self.optim_s1)
        if self.use_GAN:
            self.crit_GAN = networks.GANLoss(
                use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=adam_betas)
            self.optimizers.append(self.optim_D)
        self.fake_pool = ImagePool(opt.pool_size)

    # ---- init / restore weights ------------------------------------------
    def load_generators(epoch):
        """Restore the stage-1 net and both stage-2 nets for ``epoch``."""
        self.load_network(self.netT_s1, 'netT_s1', epoch)
        self.load_network(self.netT_s2e, 'netT_s2e', epoch)
        self.load_network(self.netT_s2d, 'netT_s2d', epoch)

    if not self.is_train:
        load_generators(opt.which_epoch)
    elif not opt.continue_train:
        # fresh run: stage 1 comes from its own training run, stage 2 is new
        self.load_network(self.netT_s1, 'netT', 'latest', self.opt_s1.id)
        networks.init_weights(self.netT_s2e, init_type=opt.init_type)
        networks.init_weights(self.netT_s2d, init_type=opt.init_type)
        if self.use_GAN:
            networks.init_weights(self.netD, init_type=opt.init_type)
    else:
        load_generators(opt.which_epoch)
        self.load_optim(self.optim, 'optim', opt.which_epoch)
        if self.use_GAN:
            self.load_network(self.netD, 'netD', opt.which_epoch)
            self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)

    # ---- LR schedulers ---------------------------------------------------
    if self.is_train:
        self.schedulers = [networks.get_scheduler(o, opt) for o in self.optimizers]