Esempio n. 1
0
 def _net_init(self, init_type='kaiming'):
     """Initialize the weights of ``self.model``.

     Prints the selected scheme to stdout, then delegates the actual work
     to ``init_weights``.

     Args:
         init_type: Name of the initialization scheme, forwarded unchanged
             to ``init_weights``.
     """
     banner = '==> Initializing the network using [%s]' % init_type
     print(banner)
     init_weights(self.model, init_type)
    def initialize(self, opt):
        """Build the transformer, the optional discriminator and the training state.

        Args:
            opt: Options object read via attribute access (``which_model_T``,
                ``pose_type``, ``gpu_ids``, ``loss_weight_gan``, ``lr`` ...).

        Raises:
            NotImplementedError: If ``opt.which_model_T`` is neither
                ``'resnet'`` nor ``'unet'``.
        """
        super(SupervisedPoseTransferModel, self).initialize(opt)
        ###################################
        # define transformer
        ###################################
        # The transformer input is an RGB image (3 channels) concatenated
        # with the pose representation.
        if opt.which_model_T == 'resnet':
            self.netT = networks.ResnetGenerator(
                input_nc=3 + self.get_pose_dim(opt.pose_type),
                output_nc=3,
                ngf=opt.T_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=not opt.no_dropout,
                n_blocks=9,
                gpu_ids=opt.gpu_ids)
        elif opt.which_model_T == 'unet':
            self.netT = networks.UnetGenerator_v2(
                input_nc=3 + self.get_pose_dim(opt.pose_type),
                output_nc=3,
                num_downs=8,
                ngf=opt.T_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=not opt.no_dropout,
                gpu_ids=opt.gpu_ids)
        else:
            raise NotImplementedError(
                'unsupported which_model_T: %s' % opt.which_model_T)

        if opt.gpu_ids:
            self.netT.cuda()
        networks.init_weights(self.netT, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        # A discriminator is only needed while training with a positive
        # GAN loss weight.
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                # Conditional discriminator also sees the pose channels.
                input_nc=(3 + self.get_pose_dim(opt.pose_type)
                          if opt.D_cond else 3),
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=3,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        if self.is_train:
            self.loss_functions = []
            self.schedulers = []
            self.optimizers = []

            self.crit_L1 = nn.L1Loss()
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids)
            self.crit_psnr = networks.PSNR()
            self.crit_ssim = networks.SSIM()
            self.loss_functions += [self.crit_L1, self.crit_vgg]
            self.optim = torch.optim.Adam(self.netT.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2))
            self.optimizers += [self.optim]

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                # Register the GAN criterion itself; the previous code
                # appended the boolean flag `self.use_GAN` by mistake.
                self.loss_functions.append(self.crit_GAN)
                self.optimizers.append(self.optim_D)
            # todo: add pose loss
            # One scheduler per optimizer, in the same order.
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

            self.fake_pool = ImagePool(opt.pool_size)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # NOTE(review): the checkpoint label is taken from
            # `opt.which_model` here -- confirm this is the intended option.
            self.load_network(self.netT, 'netT', opt.which_model)
                                                shuffle=True)

# Validation data: two image folders read through the project's dataset
# wrapper, served one sample at a time in a fixed order.
test_data = DataLoader('./data_256/valA',
                       './data_256/valB',
                       transform=transform)
test_data_loader = torch.utils.data.DataLoader(dataset=test_data,
                                               batch_size=1,
                                               shuffle=False)
# Keep the first validation batch around as a fixed (input, target) pair.
test_input, test_target = next(iter(test_data_loader))

# Models
G = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=6)
# The discriminator takes 6 input channels (twice the generator's 3).
D = Discriminator(input_nc=6, ndf=64)
#G.cuda()
#D.cuda()
init_weights(G)
init_weights(D)
#G.init_weights(mean=0.0, std=0.02)
#D.init_weights(mean=0.0, std=0.02)

# Loss function
#BCE_loss = torch.nn.BCELoss()#.cuda()
# NOTE: despite its name, BCE_loss is bound to the project's GANLoss.
BCE_loss = GANLoss()
L1_loss = torch.nn.L1Loss()  #.cuda()

# Optimizers
G_optimizer = torch.optim.Adam(G.parameters(),
                               lr=params.lrG,
                               betas=(params.beta1, params.beta2))
D_optimizer = torch.optim.Adam(D.parameters(),
                               lr=params.lrD,
Esempio n. 4
0
    def initialize(self, opt):
        """Set up the variational U-Net, optional discriminator and training state.

        Args:
            opt: Options object read via attribute access (``pose_type``,
                ``vunet_nf``, ``gpu_ids``, ``loss_weight_gan``,
                ``continue_train``, ``which_epoch`` ...).
        """
        super(VUnetPoseTransferModel, self).initialize(opt)

        # ---------------- transformer ----------------
        vunet_config = dict(
            input_nc_dec=self.get_pose_dim(opt.pose_type),
            input_nc_enc=self.get_appearance_dim(opt.appearance_type),
            output_nc=self.get_output_dim(opt.output_type),
            nf=opt.vunet_nf,
            max_nf=opt.vunet_max_nf,
            input_size=opt.fine_size,
            n_latent_scales=opt.vunet_n_latent_scales,
            bottleneck_factor=opt.vunet_bottleneck_factor,
            box_factor=opt.vunet_box_factor,
            n_residual_blocks=2,
            norm_layer=networks.get_norm_layer(opt.norm),
            activation=nn.ReLU(False),
            use_dropout=False,
            gpu_ids=opt.gpu_ids,
            output_tanh=False,
        )
        self.netT = networks.VariationalUnet(**vunet_config)
        if opt.gpu_ids:
            self.netT.cuda()
        networks.init_weights(self.netT, init_type=opt.init_type)

        # ---------------- discriminator ----------------
        # Only built when training with a positive GAN loss weight.
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if not self.use_GAN:
            self.netD = None
        else:
            # RGB image, plus the pose channels for a conditional D.
            disc_input_nc = 3
            if opt.D_cond:
                disc_input_nc = 3 + self.get_pose_dim(opt.pose_type)
            self.netD = networks.define_D_from_params(
                input_nc=disc_input_nc,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=opt.D_n_layer,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)

        # ---------------- criteria / optimizers ----------------
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.crit_vgg = networks.VGGLoss_v2(
                self.gpu_ids,
                opt.content_layer_weight,
                opt.style_layer_weight,
                opt.shifted_style)
            adam_betas = (opt.beta1, opt.beta2)
            self.optim = torch.optim.Adam(
                self.netT.parameters(),
                lr=opt.lr,
                betas=adam_betas,
                weight_decay=opt.weight_decay)
            self.optimizers = [self.optim]

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan',
                    tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=opt.lr_D,
                    betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_D)
            # todo: add pose loss
            self.fake_pool = ImagePool(opt.pool_size)

        # ---------------- checkpoint loading ----------------
        # Testing always restores netT; training restores it (together with
        # the optimizer state and, if present, the discriminator) only when
        # continuing a previous run.
        if not self.is_train or opt.continue_train:
            self.load_network(self.netT, 'netT', opt.which_epoch)
            if self.is_train:
                self.load_optim(self.optim, 'optim', opt.which_epoch)
                if self.use_GAN:
                    self.load_network(self.netD, 'netD', opt.which_epoch)
                    self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)

        # ---------------- schedulers ----------------
        if self.is_train:
            # One scheduler per optimizer, in the same order.
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]
Esempio n. 5
0
    def initialize(self, opt):
        """Create the encoder/decoder pair, losses, optimizer and schedulers.

        Exactly one modality is selected from ``opt.use_shape``,
        ``opt.use_edge`` and ``opt.use_color`` (checked in that order).

        Args:
            opt: Options object read via attribute access.

        Raises:
            ValueError: If none of the three ``use_*`` flags is set.
        """
        super(EncoderDecoderFramework, self).initialize(opt)

        # ---------------- modality selection ----------------
        if opt.use_shape:
            modality = 'shape'
        elif opt.use_edge:
            modality = 'edge'
        elif opt.use_color:
            modality = 'color'
        else:
            raise ValueError(
                'either use_shape, use_edge, use_color should be set')
        self.encoder_type = modality
        self.encoder_name = modality + '_encoder'
        self.decoder_name = 'decoder'

        # ---------------- encoder ----------------
        self.encoder = networks.define_image_encoder(opt, self.encoder_type)

        # ---------------- decoder hyper-parameters ----------------
        # Per-modality sizes live on opt as <modality>_ndowns/_nf/_nof.
        ndowns = getattr(opt, modality + '_ndowns')
        nf = getattr(opt, modality + '_nf')
        nof = getattr(opt, modality + '_nof')
        if modality == 'shape':
            output_nc, output_activation = 7, None
            # Guided decoding is not allowed for the shape modality.
            assert opt.decode_guided == False
        elif modality == 'edge':
            output_nc, output_activation = 1, None
        else:  # 'color'
            output_nc, output_activation = 3, nn.Tanh

        # Encoders of type 'normal'/'st' give a feature map larger than 1x1
        # only if feat_size below exceeds 1; otherwise a 1x1 feature is used.
        if opt.encoder_type in {'normal', 'st'}:
            self.feat_size = 256 // 2**opt.edge_ndowns
            self.mid_feat_size = self.feat_size
        else:
            self.feat_size = 1
            self.mid_feat_size = 8

        # ---------------- decoder ----------------
        self.use_concat_net = False
        guided = opt.decode_guided
        has_spatial_feat = self.feat_size > 1
        if has_spatial_feat:
            # The guide is concatenated directly to the feature channels.
            decoder_input_nc = nof + opt.shape_nc if guided else nof
            decoder_num_ups = ndowns
        else:
            decoder_input_nc = nof
            decoder_num_ups = 5 if guided else 8

        self.decoder = networks.define_image_decoder_from_params(
            input_nc=decoder_input_nc,
            output_nc=output_nc,
            nf=nf,
            num_ups=decoder_num_ups,
            norm=opt.norm,
            output_activation=output_activation,
            gpu_ids=opt.gpu_ids,
            init_type=opt.init_type)

        if guided and not has_spatial_feat:
            # With a 1x1 feature, a separate network fuses feature and guide.
            self.concat_net = networks.FeatureConcatNetwork(
                feat_nc=nof,
                guide_nc=opt.shape_nc,
                nblocks=3,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids)
            if len(self.gpu_ids) > 0:
                self.concat_net.cuda()
            networks.init_weights(self.concat_net, opt.init_type)
            self.use_concat_net = True
            print('encoder_decoder contains a feature_concat_network!')

        # ---------------- checkpoint loading ----------------
        # NOTE(review): the option is spelled `which_opoch` here -- confirm
        # it matches the name declared by the option parser.
        if not self.is_train or self.opt.continue_train:
            checkpointed = (
                (self.encoder, self.encoder_name),
                (self.decoder, self.decoder_name),
            )
            for network, label in checkpointed:
                self.load_network(network, label, opt.which_opoch)
            if self.use_concat_net:
                self.load_network(self.concat_net, 'concat_net',
                                  opt.which_opoch)

        # ---------------- losses ----------------
        self.crit_L1 = networks.SmoothLoss(nn.L1Loss())
        self.crit_CE = networks.SmoothLoss(nn.CrossEntropyLoss())
        self.loss_functions = [self.crit_L1, self.crit_CE]

        # ---------------- optimizer / schedulers ----------------
        # NOTE(review): only encoder and decoder parameters are optimized;
        # concat_net parameters are not in this optimizer -- confirm intended.
        param_groups = [
            {'params': self.encoder.parameters()},
            {'params': self.decoder.parameters()},
        ]
        self.optim = torch.optim.Adam(param_groups,
                                      lr=opt.lr,
                                      betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optim]
        self.schedulers = [
            networks.get_scheduler(optimizer, opt)
            for optimizer in self.optimizers
        ]
    def initialize(self, opt):
        """Build the flow network, its criteria, optimizer and schedulers.

        Args:
            opt: Options object read via attribute access (``which_model``,
                ``input_type1``, ``input_type2``, ``start_scale``,
                ``num_scale``, ``lr`` ...).

        Raises:
            NotImplementedError: If ``opt.which_model`` is neither ``'unet'``
                nor ``'unet_v2'``.
        """
        super(FlowRegressionModel, self).initialize(opt)
        ###################################
        # define flow networks
        ###################################
        # Both inputs are concatenated along the channel axis.
        if opt.which_model == 'unet':
            self.netF = networks.FlowUnet(
                input_nc=self.get_input_dim(opt.input_type1) +
                self.get_input_dim(opt.input_type2),
                nf=opt.nf,
                start_scale=opt.start_scale,
                num_scale=opt.num_scale,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids,
            )
        elif opt.which_model == 'unet_v2':
            self.netF = networks.FlowUnet_v2(
                input_nc=self.get_input_dim(opt.input_type1) +
                self.get_input_dim(opt.input_type2),
                nf=opt.nf,
                max_nf=opt.max_nf,
                start_scale=opt.start_scale,
                num_scales=opt.num_scale,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids,
            )
        else:
            # Fail here with a clear message instead of a later
            # AttributeError on the missing `self.netF`.
            raise NotImplementedError(
                'unsupported which_model: %s' % opt.which_model)
        if opt.gpu_ids:
            self.netF.cuda()
        networks.init_weights(self.netF, init_type=opt.init_type)
        ###################################
        # loss and optimizers
        ###################################
        self.crit_flow = networks.MultiScaleFlowLoss(
            start_scale=opt.start_scale,
            num_scale=opt.num_scale,
            loss_type=opt.flow_loss_type)
        # Visibility classes: 0-visible, 1-invisible, 2-background.
        self.crit_vis = nn.CrossEntropyLoss()
        if opt.use_ss_flow_loss:
            self.crit_flow_ss = networks.SS_FlowLoss(loss_type='l1')
        if self.is_train:
            self.optimizers = []
            self.optim = torch.optim.Adam(self.netF.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2),
                                          weight_decay=opt.weight_decay)
            self.optimizers.append(self.optim)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # load trained model for test
            print('load pretrained model')
            self.load_network(self.netF, 'netF', opt.which_epoch)
        elif opt.resume_train:
            # resume training: restore both the weights and optimizer state
            print('resume training')
            self.load_network(self.netF, 'netF', opt.last_epoch)
            self.load_optim(self.optim, 'optim', opt.last_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            # One scheduler per optimizer, in the same order.
            self.schedulers = [
                networks.get_scheduler(optim, opt)
                for optim in self.optimizers
            ]
def main(config, needs_save, i):
    """Train a ResUNet and validate it after every epoch.

    Args:
        config: Experiment configuration read via attribute access
            (``run``, ``train_dataset``, ``val_dataset``, ``model``,
            ``optimizer``, ``save`` ...).
        needs_save: When truthy, the config, logs, checkpoints and preview
            images are written under the output directory.
        i: Split selector forwarded to ``get_cv_splits`` and
            ``get_output_dir_path``.
    """
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    # Both splits must come from the same root directory.
    assert config.train_dataset.root_dir_path == config.val_dataset.root_dir_path
    train_patient_ids, val_patient_ids = get_cv_splits(
        config.train_dataset.root_dir_path, i)

    seed = check_manual_seed()
    print('Using seed: {}'.format(seed))

    class_name_to_index = config.label_to_id._asdict()
    index_to_class_name = {
        index: name for name, index in class_name_to_index.items()
    }

    def build_loader(mode, dataset_cfg, patient_ids):
        # Same keyword set for both splits; only the source section differs.
        return get_data_loader(
            mode=mode,
            dataset_name=dataset_cfg.dataset_name,
            root_dir_path=dataset_cfg.root_dir_path,
            patient_ids=patient_ids,
            batch_size=dataset_cfg.batch_size,
            num_workers=dataset_cfg.num_workers,
            volume_size=dataset_cfg.volume_size,
        )

    train_data_loader = build_loader('train', config.train_dataset,
                                     train_patient_ids)
    val_data_loader = build_loader('val', config.val_dataset,
                                   val_patient_ids)

    model = ResUNet(
        input_dim=config.model.input_dim,
        output_dim=config.model.output_dim,
        filters=config.model.filters,
    )
    print(model)

    if config.run.use_cuda:
        model.cuda()
        model = nn.DataParallel(model)

    if not config.model.saved_model:
        print('Initializing weights.')
        init_weights(model, init_type=config.model.init_type)
    else:
        print('Loading saved model: {}'.format(config.model.saved_model))
        model.load_state_dict(torch.load(config.model.saved_model))

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(trainable_params,
                           lr=config.optimizer.lr,
                           betas=config.optimizer.betas,
                           weight_decay=config.optimizer.weight_decay)

    dice_loss = SoftDiceLoss()
    focal_loss = FocalLoss(
        gamma=config.focal_loss.gamma,
        alpha=config.focal_loss.alpha,
    )
    active_contour_loss = ActiveContourLoss(
        weight=config.active_contour_loss.weight)
    dice_coeff = DiceCoefficient(
        n_classes=config.metric.n_classes,
        index_to_class_name=index_to_class_name,
    )
    one_hot_encoder = OneHotEncoder(n_classes=config.metric.n_classes).forward

    def to_device(batch):
        # Cast image to float and label to long, on GPU when enabled.
        image, label = batch['image'], batch['label']
        if config.run.use_cuda:
            return (image.cuda(non_blocking=True).float(),
                    label.cuda(non_blocking=True).long())
        return image.float(), label.long()

    def loss_terms(output, target):
        # The three loss components, in reporting order.
        return (dice_loss(output, target),
                focal_loss(output, target),
                active_contour_loss(output, target))

    def collect_measures(l_dice, l_focal, l_active_contour, m_dice):
        # Scalar losses first, then the per-class dice entries.
        measures = {
            'SoftDiceLoss': l_dice.item(),
            'FocalLoss': l_focal.item(),
            'ActiveContourLoss': l_active_contour.item(),
        }
        measures.update(m_dice)
        if config.run.use_cuda:
            torch.cuda.synchronize()
        return measures

    def train(engine, batch):
        # One optimization step; the learning rate is adjusted per call
        # from the current epoch.
        adjust_learning_rate(optimizer,
                             engine.state.epoch,
                             initial_lr=config.optimizer.lr,
                             n_epochs=config.run.n_epochs,
                             gamma=config.optimizer.gamma)
        model.train()
        image, label = to_device(batch)

        optimizer.zero_grad()
        output = model(image)
        # Channel 0 of the one-hot target is dropped.
        target = one_hot_encoder(label)[:, 1:, ...]

        l_dice, l_focal, l_active_contour = loss_terms(output, target)
        l_total = l_dice + l_focal + l_active_contour
        l_total.backward()
        optimizer.step()

        m_dice = dice_coeff.update(output.detach(), label)
        return collect_measures(l_dice, l_focal, l_active_contour, m_dice)

    def evaluate(engine, batch):
        # Forward pass and metrics only; no gradients are recorded.
        model.eval()
        image, label = to_device(batch)

        with torch.no_grad():
            output = model(image)
            target = one_hot_encoder(label)[:, 1:, ...]
            l_dice, l_focal, l_active_contour = loss_terms(output, target)
            m_dice = dice_coeff.update(output.detach(), label)

        return collect_measures(l_dice, l_focal, l_active_contour, m_dice)

    output_dir_path = get_output_dir_path(config, i)
    trainer = Engine(train)
    evaluator = Engine(evaluate)
    timer = Timer(average=True)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir_path,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.run.n_epochs + 1,
            create_dir=True,
        )

    monitoring_metrics = ['SoftDiceLoss', 'FocalLoss', 'ActiveContourLoss']
    monitoring_metrics.extend(class_name_to_index)

    # Track a running average of every metric on both engines
    # (trainer first, then evaluator).
    for engine_ in (trainer, evaluator):
        for metric in monitoring_metrics:
            running_avg = RunningAverage(
                alpha=0.98,
                output_transform=lambda x, key=metric: x[key])
            running_avg.attach(engine_, metric)

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    def save_images(evaluator, epoch):
        # Render the evaluator's current batch as a preview image
        # (middle slice along the last axis).
        image, label = to_device(evaluator.state.batch)

        with torch.no_grad():
            pred = model(image)

        # True where channels 0, 1 and 2 are all below 0.5.
        mask = ((pred[:, 0, ...] < 0.5)
                * (pred[:, 1, ...] < 0.5)
                * (pred[:, 2, ...] < 0.5))

        # Predicted index shifted by one; masked positions become 0.
        output = torch.ones_like(label)
        output += pred.argmax(1)
        output[mask] = 0

        image = image.detach().cpu().float()
        label = label.detach().cpu().unsqueeze(1).float()
        output = output.detach().cpu().unsqueeze(1).float()

        z_middle = image.shape[-1] // 2
        image = image[:, 0, ..., z_middle]
        label = label[:, 0, ..., z_middle]
        output = output[:, 0, ..., z_middle]

        # Fall back to the data range when no display window is configured.
        vmax = (config.save.image_vmax
                if config.save.image_vmax is not None else image.max())
        vmin = (config.save.image_vmin
                if config.save.image_vmin is not None else image.min())

        # Rescale the clipped intensities to [0, 255].
        image = np.clip(image, vmin, vmax)
        image -= vmin
        image /= (vmax - vmin)
        image *= 255.0

        save_path = os.path.join(output_dir_path,
                                 'result_{}.png'.format(epoch))
        save_images_via_plt(image, label, output, config.save.n_save_images,
                            config, save_path)

    @trainer.on(Events.STARTED)
    def call_save_config(engine):
        if needs_save:
            return save_config(engine, config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_save_logs(engine):
        if needs_save:
            return save_logs('train', engine, config, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def call_print_times(engine):
        return print_times(engine, config, pbar, timer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        # A single pass over the validation loader after every epoch.
        evaluator.run(val_data_loader, 1)
        if not needs_save:
            return
        save_logs('val', evaluator, config, output_dir_path)
        save_images(evaluator, trainer.state.epoch)

    if needs_save:
        # Registered after the validation handler above.
        trainer.add_event_handler(
            event_name=Events.EPOCH_COMPLETED,
            handler=checkpoint_handler,
            to_save={'model': model, 'optim': optimizer})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    max_iterations = config.run.n_epochs * len(train_data_loader)
    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, max_iterations))

    trainer.run(train_data_loader, config.run.n_epochs)
    def initialize(self, opt):
        """Build generator, optional pixel warper / discriminator / flow net,
        and the training state.

        Args:
            opt: Options object read via attribute access (``which_model_G``,
                ``G_pix_warp``, ``loss_weight_gan``, ``flow_on_the_fly``,
                ``resume_train`` ...).

        Raises:
            NotImplementedError: If ``opt.which_model_G`` is neither
                ``'unet'`` nor ``'dual_unet'``.
        """
        super(PoseTransferModel, self).initialize(opt)
        ###################################
        # define generator
        ###################################
        if opt.G_activation == 'leaky_relu':
            generator_activation = nn.LeakyReLU(0.1)
        else:
            generator_activation = nn.ReLU()

        if opt.which_model_G == 'unet':
            self.netG = networks.UnetGenerator(
                # Appearance and pose inputs are joined into one tensor spec.
                input_nc=self.get_tensor_dim('+'.join(
                    [opt.G_appearance_type, opt.G_pose_type])),
                output_nc=3,
                nf=opt.G_nf,
                max_nf=opt.G_max_nf,
                num_scales=opt.G_n_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                activation=generator_activation,
                use_dropout=opt.use_dropout,
                gpu_ids=opt.gpu_ids)
        elif opt.which_model_G == 'dual_unet':
            self.netG = networks.DualUnetGenerator(
                pose_nc=self.get_tensor_dim(opt.G_pose_type),
                appearance_nc=self.get_tensor_dim(opt.G_appearance_type),
                output_nc=3,
                aux_output_nc=[],
                nf=opt.G_nf,
                max_nf=opt.G_max_nf,
                num_scales=opt.G_n_scale,
                num_warp_scales=opt.G_n_warp_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                vis_mode=opt.G_vis_mode,
                activation=generator_activation,
                use_dropout=opt.use_dropout,
                no_end_norm=opt.G_no_end_norm,
                gpu_ids=opt.gpu_ids,
            )
        else:
            # Fail here with a clear message instead of a later
            # AttributeError on the missing `self.netG`.
            raise NotImplementedError(
                'unsupported which_model_G: %s' % opt.which_model_G)
        if opt.gpu_ids:
            self.netG.cuda()
        networks.init_weights(self.netG, init_type=opt.init_type)
        ###################################
        # define external pixel warper
        ###################################
        if opt.G_pix_warp:
            pix_warp_n_scale = opt.G_n_scale
            self.netPW = networks.UnetGenerator_MultiOutput(
                input_nc=self.get_tensor_dim(opt.G_pix_warp_input_type),
                output_nc=[1],  # only use one output branch (weight mask)
                nf=32,
                max_nf=128,
                num_scales=pix_warp_n_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                activation=nn.ReLU(False),
                use_dropout=False,
                gpu_ids=opt.gpu_ids)
            if opt.gpu_ids:
                self.netPW.cuda()
            networks.init_weights(self.netPW, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        # Only built when training with a positive GAN loss weight.
        self.use_gan = self.is_train and self.opt.loss_weight_gan > 0
        if self.use_gan:
            self.netD = networks.NLayerDiscriminator(
                input_nc=self.get_tensor_dim(opt.D_input_type_real),
                ndf=opt.D_nf,
                n_layers=opt.D_n_layers,
                use_sigmoid=(opt.gan_type == 'dcgan'),
                output_bias=True,
                gpu_ids=opt.gpu_ids,
            )
            if opt.gpu_ids:
                self.netD.cuda()
            networks.init_weights(self.netD, init_type=opt.init_type)
        ###################################
        # load optical flow model
        ###################################
        if opt.flow_on_the_fly:
            self.netF = load_flow_network(opt.pretrained_flow_id,
                                          opt.pretrained_flow_epoch,
                                          opt.gpu_ids)
            # The flow network is used in eval mode and gets no optimizer.
            self.netF.eval()
            if opt.gpu_ids:
                self.netF.cuda()
        ###################################
        # loss and optimizers
        ###################################
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.crit_vgg = networks.VGGLoss(
                opt.gpu_ids,
                shifted_style=opt.shifted_style_loss,
                content_weights=opt.vgg_content_weights)
            # With the pixel warper enabled only netPW is optimized;
            # otherwise the generator is.
            optimized_net = self.netPW if opt.G_pix_warp else self.netG
            self.optim = torch.optim.Adam(optimized_net.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2),
                                          weight_decay=opt.weight_decay)
            self.optimizers = [self.optim]
            if self.use_gan:
                self.crit_gan = networks.GANLoss(
                    use_lsgan=(opt.gan_type == 'lsgan'))
                if self.gpu_ids:
                    self.crit_gan.cuda()
                self.optim_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=opt.lr_D,
                    betas=(opt.beta1, opt.beta2),
                    weight_decay=opt.weight_decay_D)
                self.optimizers += [self.optim_D]

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # load trained model for testing
            self.load_network(self.netG, 'netG', opt.which_epoch)
            if opt.G_pix_warp:
                self.load_network(self.netPW, 'netPW', opt.which_epoch)
        elif opt.pretrained_G_id is not None:
            # load pretrained network; this takes precedence over
            # opt.resume_train
            self.load_network(self.netG, 'netG', opt.pretrained_G_epoch,
                              opt.pretrained_G_id)
        elif opt.resume_train:
            # resume training
            self.load_network(self.netG, 'netG', opt.which_epoch)
            self.load_optim(self.optim, 'optim', opt.which_epoch)
            if self.use_gan:
                self.load_network(self.netD, 'netD', opt.which_epoch)
                self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
            if opt.G_pix_warp:
                self.load_network(self.netPW, 'netPW', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            # One scheduler per optimizer, in the same order.
            self.schedulers = [
                networks.get_scheduler(optim, opt)
                for optim in self.optimizers
            ]
Esempio n. 9
0
    def __init__(self, opt):
        """Build the DASGIL generator and, in training mode, its adversarial setup.

        Always creates the generator and resolves the checkpoint directory.
        When ``opt.isTrain`` is set it additionally creates the generator
        optimizer, the discriminator (``opt.dis_type`` "FD" or "CD") with its
        optimizer, the cross-entropy criterion, the pre-allocated input
        tensors and both StepLR schedulers.

        Checkpoints are read from ``<checkpoints_dir>/<name>/<epoch>_net_gen.pth``
        (test mode or ``opt.continue_train``) and ``<epoch>_net_dis.pth``
        (training with ``opt.continue_train`` only).

        Parameters:
            opt -- option namespace; the attributes read are ``isTrain``,
                   ``gpu_ids``, ``checkpoints_dir``, ``name`` and, depending
                   on the mode, ``lr``, ``lr_dis``, ``dis_type``, ``dis_nc``,
                   ``dis_nlayers``, ``num_classes``, ``batch_size``,
                   ``resized_h``, ``resized_w``, ``continue_train``,
                   ``which_epoch``, ``step_lr_epoch`` and ``gamma_lr``.

        Raises:
            AssertionError -- ``opt.dis_type`` is neither "FD" nor "CD".
            ValueError     -- same condition when asserts are disabled (-O).
        """
        super(DASGIL, self).__init__()
        self.opt = opt
        self.isTrain = opt.isTrain
        # NOTE(review): gpu_ids is compared with an int, so it is presumably a
        # single device index where a negative value selects the CPU -- confirm.
        if self.opt.gpu_ids >= 0:
            self.Tensor = torch.cuda.FloatTensor
        else:
            self.Tensor = torch.FloatTensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # generator
        self.generator = Generator(opt)
        if opt.isTrain:
            self.generator.weight_init(0, 0.02)
            self.gen_parameters = list(self.generator.parameters())
            self.gen_optimizer = torch.optim.Adam(self.gen_parameters,
                                                  lr=opt.lr,
                                                  betas=(0.9, 0.999))
            # discriminator
            # The assert is kept so existing callers still see AssertionError;
            # the explicit raise below covers runs where asserts are stripped,
            # which previously only printed a message and then crashed later
            # on the missing ``self.dis_f``.
            assert opt.dis_type == "FD" or opt.dis_type == "CD"
            if opt.dis_type == "FD":
                self.dis_f = FlattenDiscriminator(opt.dis_nc, opt.dis_nlayers)
            elif opt.dis_type == "CD":
                self.dis_f = CascadeDiscriminator()
            else:
                raise ValueError(
                    "only FD or CD is supported as dis_type, got %r" %
                    (opt.dis_type, ))

            init_weights(self.dis_f, 'normal', opt=self.opt)
            dis_params = list(self.dis_f.parameters())
            # Only parameters that require gradients are handed to Adam.
            self.dis_optimizer = torch.optim.Adam(
                [p for p in dis_params if p.requires_grad],
                lr=opt.lr_dis,
                betas=(0.5, 0.9))
            weight = torch.ones(self.opt.num_classes)
            self.criterion_corssentropy = CrossEntropyLoss2d(weight)

            # initialize inputs
            def new_input(channels):
                # One (batch_size, channels, resized_h, resized_w) tensor of
                # the type selected above.
                return self.Tensor(opt.batch_size, channels, opt.resized_h,
                                   opt.resized_w)

            self.input_GAN_real = new_input(3)
            self.input_rgb_A = new_input(3)
            self.input_depth_A = new_input(1)
            self.input_seg_A = new_input(3)

            self.input_rgb_A_prime = new_input(3)
            self.input_depth_A_prime = new_input(1)
            self.input_seg_A_prime = new_input(3)

            self.input_rgb_B = new_input(3)
            self.input_depth_B = new_input(1)
            self.input_seg_B = new_input(3)

        # load checkpoints
        if not opt.isTrain or opt.continue_train:
            save_filename_generator = '%d_net_%s' % (opt.which_epoch, 'gen')
            save_path_generator = os.path.join(self.save_dir,
                                               save_filename_generator)
            filename_generator = save_path_generator + '.pth'
            self.generator.load_state_dict(torch.load(filename_generator))
        # The discriminator only exists in training mode, so its checkpoint is
        # restored only when training is being resumed.
        if opt.isTrain and opt.continue_train:
            save_filename_dis = '%d_net_%s' % (opt.which_epoch, 'dis')
            save_path_dis = os.path.join(self.save_dir, save_filename_dis)
            filename_dis = save_path_dis + '.pth'
            self.dis_f.load_state_dict(torch.load(filename_dis))

        if opt.isTrain:
            # lr scheduler update
            self.gen_scheduler = lr_scheduler.StepLR(
                self.gen_optimizer,
                step_size=opt.step_lr_epoch,
                gamma=opt.gamma_lr)
            # Position the schedule at the epoch named by the options.
            self.gen_scheduler.last_epoch = opt.which_epoch
            self.dis_scheduler = lr_scheduler.StepLR(
                self.dis_optimizer,
                step_size=opt.step_lr_epoch,
                gamma=opt.gamma_lr)
            self.dis_scheduler.last_epoch = opt.which_epoch
def main():
    """Fine-tune a pretrained ENet over up to ten negative-sample rounds.

    Loads ``checkpoint_300.pth`` from ``project_root + args.log_dir``,
    scales ``args.lr`` by 0.001 and then, for each round 1..10, rebuilds the
    training set with additional negative samples and runs ``train_multi``
    for 15 epochs. The loop stops early once ``adv_model_selection`` reports
    fewer than one model; ``final_model_selection`` is called at the end.

    Relies on module-level state defined elsewhere in the file (``args``,
    ``srm_trainable``, ``project_root``, ``dst_dir``, ``normalize``,
    ``kwargs``, ``val_loader``, ``THRESHOLD_MAX``).
    """
    # instantiate the network and initialise its weights
    net = ENet()
    networks.print_network(net)
    networks.init_weights(net, init_type='normal')
    net.init_convFilter(trainable=srm_trainable)

    if args.cuda:
        net.cuda()

    print('using pretrained model')
    checkpoint_path = project_root + args.log_dir + '/checkpoint_300.pth'
    saved = torch.load(checkpoint_path)
    net.load_state_dict(saved['state_dict'])
    args.lr *= 0.001
    threshold = THRESHOLD_MAX

    # loss functions
    ce_loss = nn.CrossEntropyLoss().cuda()
    l1_loss = nn.L1Loss(reduction='sum').cuda()

    # optimizer: leave out the 'convFilter1' parameters unless they are
    # meant to be trained
    if srm_trainable:
        optimizer = create_optimizer(net.parameters(), args.lr)
    else:
        kept_params = [
            weight for label, weight in net.named_parameters()
            if 'convFilter1' not in label
        ]
        optimizer = create_optimizer(kept_params, args.lr)

    errors_all_rounds = []
    for itr in np.arange(1, 11):
        args.dataroot = dst_dir
        errors_this_round = []

        # extend the original training dataset with negative samples
        construct_negative_samples(itr)

        augmentation = transforms.Compose([
            transforms.RandomCrop(233),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_set = myDataset.MyDataset(args, augmentation)
        train_loader = myDataset.DataLoaderHalf(
            train_set,
            batch_size=args.batch_size,
            shuffle=True,
            half_constraint=True,
            sampler_type='RandomBalancedSampler',
            **kwargs)
        print('The number of train data:{}'.format(len(train_loader.dataset)))
        args.epochs = 15

        train_multi(train_loader, optimizer, net, ce_loss, l1_loss,
                    val_loader, itr, errors_this_round, errors_all_rounds)

        # nothing to select from until this round has recorded errors
        if errors_this_round:
            adv_model_num, _ = adv_model_selection(errors_this_round,
                                                   threshold, itr)
            if adv_model_num < 1:
                break

    print(errors_all_rounds)
    print(len(errors_all_rounds) / (args.epochs - args.epochs // 2))
    final_model_selection(errors_all_rounds, threshold)
Esempio n. 11
0
    def __init__(self,
                 input_nc=3,
                 ndf=64,
                 n_layers=3,
                 norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        The network is a stride-2 conv without normalization, followed by
        ``n_layers - 1`` stride-2 conv/norm/LeakyReLU stages, one stride-1
        conv/norm/LeakyReLU stage and a final stride-1 conv that outputs a
        1-channel prediction map. Channel widths double per stage and are
        capped at ``8 * ndf``. The assembled ``self.model`` is initialized
        with ``init_weights(..., 'xavier')``.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of stride-2 conv layers, counting
                               the first one
            norm_layer      -- normalization layer: a class, or a
                               functools.partial wrapping one
        """
        super(NLayerDiscriminator, self).__init__()
        # Unwrap a functools.partial (isinstance also covers subclasses) to
        # see which normalization class is really used.
        if isinstance(norm_layer, functools.partial):
            norm_cls = norm_layer.func
        else:
            norm_cls = norm_layer
        # no need to use bias as BatchNorm2d has affine parameters
        use_bias = norm_cls != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev,
                          ndf * nf_mult,
                          kernel_size=kw,
                          stride=2,
                          padding=padw,
                          bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more widening stage, this time with stride 1.
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev,
                      ndf * nf_mult,
                      kernel_size=kw,
                      stride=1,
                      padding=padw,
                      bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)
        init_weights(self.model, 'xavier')
    def initialize(self, opt):
        """Set up the two-stage pose transfer model from the options in `opt`.

        Steps, in order:
          1. create the stage-1 (coarse) network via `_create_stage_1_net`;
          2. build the stage-2 encoder `netT_s2e` (selected by
             `opt.which_model_s2e`) and decoder `netT_s2d` (selected by
             `opt.which_model_s2d`);
          3. build the discriminator `netD` when training with a positive
             `opt.loss_weight_gan`, otherwise set it to None;
          4. create the PSNR/SSIM metrics and, when training, the VGG loss,
             GAN loss, optimizers and image pool;
          5. initialize or load network weights and optimizer states;
          6. create one scheduler per optimizer when training.

        Raises:
            NotImplementedError: `opt.which_model_s2e` or
                `opt.which_model_s2d` is not one of the handled values.
        """
        super(TwoStagePoseTransferModel, self).initialize(opt)
        ###################################
        # load pretrained stage-1 (coarse) network
        ###################################
        # presumably sets self.netT_s1 and self.opt_s1, both of which are used
        # further down -- verify against _create_stage_1_net
        self._create_stage_1_net(opt)
        ###################################
        # define stage-2 (refine) network
        ###################################
        # local patch encoder
        # s2e_nof records the encoder's output channel count; it is added to
        # the 3 image channels to size the decoder input below.
        if opt.which_model_s2e == 'patch_embed':
            self.netT_s2e = networks.LocalPatchEncoder(
                n_patch=len(opt.patch_indices),
                input_nc=3,
                output_nc=opt.s2e_nof,
                nf=opt.s2e_nf,
                max_nf=opt.s2e_max_nf,
                input_size=opt.patch_size,
                bottleneck_factor=opt.s2e_bottleneck_factor,
                n_residual_blocks=2,
                norm_layer=networks.get_norm_layer(opt.norm),
                # NOTE(review): an nn.ReLU instance is passed here, while the
                # other networks in this method receive the nn.ReLU class --
                # confirm which form LocalPatchEncoder expects.
                activation=nn.ReLU(False),
                use_dropout=False,
                gpu_ids=opt.gpu_ids,
            )
            s2e_nof = opt.s2e_nof
        elif opt.which_model_s2e == 'patch':
            self.netT_s2e = networks.LocalPatchRearranger(
                n_patch=len(opt.patch_indices),
                image_size=opt.fine_size,
            )
            s2e_nof = 3
        elif opt.which_model_s2e == 'seg_embed':
            self.netT_s2e = networks.SegmentRegionEncoder(
                seg_nc=self.opt.seg_nc,
                input_nc=3,
                output_nc=opt.s2e_nof,
                # NOTE(review): this is opt.s2d_nf (the value also used as the
                # decoder's ngf), not opt.s2e_nf as in the 'patch_embed'
                # branch -- confirm this is intended.
                nf=opt.s2d_nf,
                input_size=opt.fine_size,
                n_blocks=3,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                grid_level=opt.s2e_grid_level,
                gpu_ids=opt.gpu_ids,
            )
            s2e_nof = opt.s2e_nof + opt.s2e_grid_level
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netT_s2e.cuda()

        # decoder
        if self.opt.which_model_s2d == 'resnet':
            self.netT_s2d = networks.ResnetGenerator(
                input_nc=3 + s2e_nof,
                output_nc=3,
                ngf=opt.s2d_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                n_blocks=opt.s2d_nblocks,
                gpu_ids=opt.gpu_ids,
                output_tanh=False,
            )
        elif self.opt.which_model_s2d == 'unet':
            self.netT_s2d = networks.UnetGenerator_v2(
                input_nc=3 + s2e_nof,
                output_nc=3,
                num_downs=8,
                ngf=opt.s2d_nf,
                max_nf=opt.s2d_nf * 2**3,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=False,
                gpu_ids=opt.gpu_ids,
                output_tanh=False,
            )
        elif self.opt.which_model_s2d == 'rpresnet':
            self.netT_s2d = networks.RegionPropagationResnetGenerator(
                input_nc=3 + s2e_nof,
                output_nc=3,
                ngf=opt.s2d_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                nblocks=opt.s2d_nblocks,
                gpu_ids=opt.gpu_ids,
                output_tanh=False)
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netT_s2d.cuda()
        ###################################
        # define discriminator
        ###################################
        # The discriminator is only needed when training with a GAN loss.
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                # The conditional expression binds looser than `+`, so this is
                # (3 + pose_dim) when opt.D_cond is set and 3 otherwise.
                input_nc=3 +
                self.get_pose_dim(opt.pose_type) if opt.D_cond else 3,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=opt.D_n_layer,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.optimizers = []
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids,
                                                opt.content_layer_weight,
                                                opt.style_layer_weight,
                                                opt.shifted_style)

            # One Adam optimizer covers both stage-2 networks (encoder and
            # decoder) as separate parameter groups.
            self.optim = torch.optim.Adam([{
                'params': self.netT_s2e.parameters()
            }, {
                'params': self.netT_s2d.parameters()
            }],
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim)

            # Stage 1 gets its own optimizer and learning rate, and only when
            # opt.train_s1 is set.
            # NOTE(review): optim_s1 is not restored in the continue_train
            # branch below -- confirm whether its state should be reloaded.
            if opt.train_s1:
                self.optim_s1 = torch.optim.Adam(self.netT_s1.parameters(),
                                                 lr=opt.lr_s1,
                                                 betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_s1)

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_D)
                self.fake_pool = ImagePool(opt.pool_size)
        ###################################
        # init/load model
        ###################################
        if self.is_train:
            if not opt.continue_train:
                # Fresh training run: stage 1 is loaded from the 'latest'
                # checkpoint saved under the name 'netT' (presumably in the
                # stage-1 experiment identified by self.opt_s1.id -- confirm),
                # while the stage-2 networks and the discriminator are
                # freshly initialized.
                self.load_network(self.netT_s1, 'netT', 'latest',
                                  self.opt_s1.id)
                networks.init_weights(self.netT_s2e, init_type=opt.init_type)
                networks.init_weights(self.netT_s2d, init_type=opt.init_type)
                if self.use_GAN:
                    networks.init_weights(self.netD, init_type=opt.init_type)
            else:
                # Resuming: restore all networks and the optimizer states
                # saved at opt.which_epoch.
                self.load_network(self.netT_s1, 'netT_s1', opt.which_epoch)
                self.load_network(self.netT_s2e, 'netT_s2e', opt.which_epoch)
                self.load_network(self.netT_s2d, 'netT_s2d', opt.which_epoch)
                self.load_optim(self.optim, 'optim', opt.which_epoch)
                if self.use_GAN:
                    self.load_network(self.netD, 'netD', opt.which_epoch)
                    self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
        else:
            # Testing: only the three generator networks are loaded.
            self.load_network(self.netT_s1, 'netT_s1', opt.which_epoch)
            self.load_network(self.netT_s2e, 'netT_s2e', opt.which_epoch)
            self.load_network(self.netT_s2d, 'netT_s2d', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            # One scheduler per optimizer, in the order they were appended.
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))