def initialize(self, args):
    """Build the networks, loss functions and optimizers for this model.

    Runs ``BaseModel.initialize`` first, then allocates the input buffers,
    creates the generator, the discriminator and the three DeepLab
    sub-networks on the first device listed in ``args['device_ids']``,
    optionally initializes their weights, and creates the Adam / SGD
    optimizers.

    Args:
        args: Mapping of options. Keys read here: ``batch_size``,
            ``fineSizeH``, ``fineSizeW``, ``input_nc``, ``device_ids``,
            ``net_D``, ``use_lsgan``, ``resume``, ``weigths_pool``,
            ``pretrain_model``, ``lr_gan``, ``beta1`` and ``l_rate``.
    """
    BaseModel.initialize(self, args)
    self.nb = args['batch_size']
    sizeH, sizeW = args['fineSizeH'], args['fineSizeW']

    # Pre-allocated input buffers.
    self.input_A = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
    self.input_B = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
    self.input_A_label = torch.cuda.LongTensor(self.nb, args['input_nc'],
                                               sizeH, sizeW)

    # The device is passed positionally: the keyword was ``device_id`` in
    # old PyTorch releases and ``device`` in later ones, so a keyword call
    # only works on one of the two.
    device = args['device_ids'][0]
    self.netG = networks.netG().cuda(device)
    self.netD = define_D(args['net_D']).cuda(device)
    self.deeplabPart1 = networks.DeeplabPool1().cuda(device)
    self.deeplabPart2 = networks.DeeplabPool12Pool5().cuda(device)
    self.deeplabPart3 = networks.DeeplabPool52Fc8_interp().cuda(device)

    # define loss functions
    self.criterionCE = torch.nn.CrossEntropyLoss(size_average=False)
    self.criterionAdv = networks.Advloss(use_lsgan=args['use_lsgan'],
                                         tensor=self.Tensor)

    # NOTE(review): the resume guard is assumed to cover both the random
    # init and the loading of the pretrained DeepLab weights -- confirm.
    if not args['resume']:
        # initialize networks
        self.netG.apply(weights_init)
        self.netD.apply(weights_init)
        pretrained_dict = torch.load(args['weigths_pool'] + '/' +
                                     args['pretrain_model'])
        self.deeplabPart1.weights_init(pretrained_dict=pretrained_dict)
        self.deeplabPart2.weights_init(pretrained_dict=pretrained_dict)
        self.deeplabPart3.weights_init(pretrained_dict=pretrained_dict)

    # initialize optimizers
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=args['lr_gan'],
                                        betas=(args['beta1'], 0.999))
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                        lr=args['lr_gan'],
                                        betas=(args['beta1'], 0.999))

    fc8_heads = [
        self.deeplabPart3.fc8_1,
        self.deeplabPart3.fc8_2,
        self.deeplabPart3.fc8_3,
        self.deeplabPart3.fc8_4,
    ]

    # ids of the fc8 parameters; they get their own groups with a larger
    # learning rate and must therefore be left out of the base group.
    ignored_params = set()
    for head in fc8_heads:
        ignored_params.update(id(p) for p in head.parameters())

    # Real lists are built here: ``filter(...) + filter(...)`` only works
    # on Python 2, where ``filter`` returns a list.
    base_params = [
        p for p in self.deeplabPart3.parameters()
        if id(p) not in ignored_params
    ]
    base_params += list(self.deeplabPart1.parameters())
    base_params += list(self.deeplabPart2.parameters())

    def deeplab_param_groups():
        """Return fresh SGD groups: base, fc8 weights (x10), fc8 biases (x20)."""
        groups = [{'params': base_params}]
        for kind, lr_scale in (('weight', 10), ('bias', 20)):
            for head in fc8_heads:
                groups.append({
                    'params': get_parameters(head, kind),
                    'lr': args['l_rate'] * lr_scale,
                })
        return groups

    # Each optimizer receives its own group dicts.
    self.optimizer_P = torch.optim.SGD(deeplab_param_groups(),
                                       lr=args['l_rate'],
                                       momentum=0.9,
                                       weight_decay=5e-4)
    self.optimizer_R = torch.optim.SGD(deeplab_param_groups(),
                                       lr=args['l_rate'],
                                       momentum=0.9,
                                       weight_decay=5e-4)

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG)
    networks.print_network(self.netD)
    networks.print_network(self.deeplabPart1)
    networks.print_network(self.deeplabPart2)
    networks.print_network(self.deeplabPart3)
    print('-----------------------------------------------')
def initialize(self, args):
    """Build the networks, loss functions and optimizers for this model.

    Runs ``BaseModel.initialize`` first, stores the training schedule
    options, allocates the input buffers, creates the generator, the two
    discriminators and the three DeepLab sub-networks on the first device
    listed in ``args['device_ids']``, optionally initializes their
    weights, and creates the Adam / SGD optimizers.

    Args:
        args: Mapping of options. Keys read here: ``if_adv_train``,
            ``interval_g2``, ``interval_d2``, ``batch_size``,
            ``fineSizeH``, ``fineSizeW``, ``input_nc``, ``label_nums``,
            ``device_ids``, ``net_d1``, ``net_d2``, ``use_lsgan``,
            ``resume``, ``weigths_pool``, ``pretrain_model``, ``lr_g1``,
            ``lr_g2``, ``beta1`` and ``l_rate``.
    """
    BaseModel.initialize(self, args)
    self.if_adv_train = args['if_adv_train']
    self.Iter = 0
    self.interval_g2 = args['interval_g2']
    self.interval_d2 = args['interval_d2']
    self.nb = args['batch_size']
    sizeH, sizeW = args['fineSizeH'], args['fineSizeW']

    # Pre-allocated input buffers.
    self.tImageA = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
    self.tImageB = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
    self.tLabelA = torch.cuda.LongTensor(self.nb, 1, sizeH, sizeW)
    self.tOnehotLabelA = self.Tensor(self.nb, args['label_nums'], sizeH,
                                     sizeW)

    self.loss_G = Variable()
    self.loss_D = Variable()

    # The device is passed positionally: the keyword was ``device_id`` in
    # old PyTorch releases and ``device`` in later ones, so a keyword call
    # only works on one of the two.
    device = args['device_ids'][0]
    self.netG1 = networks.netG().cuda(device)
    self.netD1 = define_D(args['net_d1'], 512).cuda(device)
    self.netD2 = define_D(args['net_d2'], args['label_nums']).cuda(device)
    self.deeplabPart1 = networks.DeeplabPool1().cuda(device)
    self.deeplabPart2 = networks.DeeplabPool12Pool5().cuda(device)
    self.deeplabPart3 = networks.DeeplabPool52Fc8_interp(
        output_nc=args['label_nums']).cuda(device)

    # define loss functions
    self.criterionCE = torch.nn.CrossEntropyLoss(size_average=False)
    self.criterionAdv = networks.Advloss(use_lsgan=args['use_lsgan'],
                                         tensor=self.Tensor)

    # NOTE(review): the resume guard is assumed to cover both the random
    # init and the loading of the pretrained DeepLab weights -- confirm.
    if not args['resume']:
        # initialize networks
        self.netG1.apply(weights_init)
        self.netD1.apply(weights_init)
        self.netD2.apply(weights_init)
        pretrained_dict = torch.load(args['weigths_pool'] + '/' +
                                     args['pretrain_model'])
        self.deeplabPart1.weights_init(pretrained_dict=pretrained_dict)
        self.deeplabPart2.weights_init(pretrained_dict=pretrained_dict)
        self.deeplabPart3.weights_init(pretrained_dict=pretrained_dict)

    # initialize optimizers
    adam_betas = (args['beta1'], 0.999)
    self.optimizer_G1 = torch.optim.Adam(self.netG1.parameters(),
                                         lr=args['lr_g1'],
                                         betas=adam_betas)
    self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
                                         lr=args['lr_g1'],
                                         betas=adam_betas)
    self.optimizer_G2 = torch.optim.Adam(
        [
            {'params': self.deeplabPart1.parameters()},
            {'params': self.deeplabPart2.parameters()},
            {'params': self.deeplabPart3.parameters()},
        ],
        lr=args['lr_g2'],
        betas=adam_betas)
    self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(),
                                         lr=args['lr_g2'],
                                         betas=adam_betas)

    fc8_heads = [
        self.deeplabPart3.fc8_1,
        self.deeplabPart3.fc8_2,
        self.deeplabPart3.fc8_3,
        self.deeplabPart3.fc8_4,
    ]

    # ids of the fc8 parameters; they get their own groups with a larger
    # learning rate and must therefore be left out of the base group.
    ignored_params = set()
    for head in fc8_heads:
        ignored_params.update(id(p) for p in head.parameters())

    # Real lists are built here: ``filter(...) + filter(...)`` only works
    # on Python 2, where ``filter`` returns a list.
    base_params = [
        p for p in self.deeplabPart3.parameters()
        if id(p) not in ignored_params
    ]
    base_params += list(self.deeplabPart1.parameters())
    base_params += list(self.deeplabPart2.parameters())

    # Group order: base, fc8 weights (lr x10), then fc8 biases (lr x20).
    deeplab_params = [{'params': base_params}]
    for kind, lr_scale in (('weight', 10), ('bias', 20)):
        for head in fc8_heads:
            deeplab_params.append({
                'params': get_parameters(head, kind),
                'lr': args['l_rate'] * lr_scale,
            })

    # NOTE(review): both SGD optimizers are handed the very same group
    # dicts, so they presumably end up sharing per-group settings such as
    # ``lr`` -- confirm this coupling is intended.
    self.optimizer_P = torch.optim.SGD(deeplab_params,
                                       lr=args['l_rate'],
                                       momentum=0.9,
                                       weight_decay=5e-4)
    self.optimizer_R = torch.optim.SGD(deeplab_params,
                                       lr=args['l_rate'],
                                       momentum=0.9,
                                       weight_decay=5e-4)

    print('---------- Networks initialized -------------')
    networks.print_network(self.netG1)
    networks.print_network(self.netD1)
    networks.print_network(self.netD2)
    networks.print_network(self.deeplabPart1)
    networks.print_network(self.deeplabPart2)
    networks.print_network(self.deeplabPart3)
    print('-----------------------------------------------')