Beispiel #1
0
    def __init__(self, config):
        """Set up data loaders, the segmentation network, optimizer and loss."""
        self.config = config
        cfg = config

        # Datasets: train/val split for the semantic-segmentation task.
        datasets = DatasetUtil.get_dataset_by_type(
            DatasetUtil.dataset_type_ss, cfg.ss_size,
            is_balance=cfg.is_balance_data, data_root=cfg.data_root_path,
            train_label_path=cfg.label_path, max_size=cfg.max_size)
        self.dataset_ss_train, _, self.dataset_ss_val = datasets
        self.data_loader_ss_train = DataLoader(
            self.dataset_ss_train, cfg.ss_batch_size, True,
            num_workers=16, drop_last=True)
        self.data_loader_ss_val = DataLoader(
            self.dataset_ss_val, cfg.ss_batch_size, False,
            num_workers=16, drop_last=True)

        # Network, optionally wrapped for multi-GPU execution.
        self.net = cfg.Net(num_classes=cfg.ss_num_classes,
                           output_stride=cfg.output_stride, arch=cfg.arch)
        if cfg.only_train_ss:
            self.net = BalancedDataParallel(0, self.net, dim=0).cuda()
        else:
            self.net = DataParallel(self.net).cuda()
        cudnn.benchmark = True

        # SGD with a 10x larger learning rate on the classifier head.
        backbone_group = {'params': self.net.module.model.backbone.parameters(),
                          'lr': cfg.ss_lr}
        classifier_group = {'params': self.net.module.model.classifier.parameters(),
                            'lr': cfg.ss_lr * 10}
        self.optimizer = optim.SGD(params=[backbone_group, classifier_group],
                                   lr=cfg.ss_lr, momentum=0.9, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=cfg.ss_milestones, gamma=0.1)

        # Pixel-wise cross entropy; 255 marks ignored (unlabeled) pixels.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=255, reduction='mean').cuda()
Beispiel #2
0
def train_model(model, model_name, hyperparams, device, epochs):
    """Generic training driver.

    Wraps the model for multi-GPU when available, builds the Adam optimizer
    and step LR schedule from `hyperparams`, runs training, then plots and
    saves the results dictionary under `model_name`.
    """
    print('Beginning Training for: ', model_name)
    print('------------------------------------')

    results = {}

    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Using ", gpu_count, " GPUs.")
        print('------------------------------------')
        model = DataParallel(model)
    model = model.to(device=device)

    optimizer = optim.Adam(model.parameters(),
                           betas=hyperparams['betas'],
                           lr=hyperparams['learning_rate'],
                           weight_decay=hyperparams['L2_reg'])
    lr_updater = lr_scheduler.StepLR(optimizer,
                                     hyperparams['lr_decay_epochs'],
                                     hyperparams['lr_decay'])

    results = train(model, optimizer, lr_updater, results, epochs=epochs)

    plot_results(results, model_name, save=True)
    np.save(model_name, results)

    return results
Beispiel #3
0
 def _make_model(self):
     """Build the pose network, load trained weights, and switch to eval mode."""
     self.logger.info("Creating graph and optimizer...")
     net = get_pose_net(self.backbone, False, self.jointnum)
     net = DataParallel(net).cuda()
     net.load_state_dict(torch.load(self.modelpath)['network'])
     # Keep only the wrapped module so inference runs on a single plain model.
     unwrapped = net.module
     unwrapped.eval()
     self.model = unwrapped
Beispiel #4
0
def main():
    """Train with early stopping on validation loss, then predict on the test set."""
    train_dataset = hrb_input.TrainDataset()
    val_dataset = hrb_input.ValidationDataset()
    test_dataset = hrb_input.TestDataset()

    # All loaders share batch size, worker count and pinned memory.
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, **loader_kwargs)

    model = network.resnet50(num_classes=1000)
    model = DataParallel(model, device_ids=[0, 1, 2])
    model.to(device)

    loss_func = F.cross_entropy
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=LR_STEPS,
                                               gamma=LR_GAMMA)

    best_loss = 100.0
    patience_counter = 0

    for epoch in range(EPOCHS):
        train(model, optimizer, loss_func, train_loader, epoch)
        val_loss = validate(model, loss_func, val_loader)

        if val_loss < best_loss:
            # Checkpoint whenever validation improves and reset patience.
            torch.save(model, MODEL_PATH)
            print('Saving improved model')
            print()
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print('Epoch(s) since best model: ', patience_counter)
            print()
        if patience_counter >= EARLY_STOPPING_EPOCHS:
            print('Early Stopping ...')
            print()
            break
        scheduler.step()

    print('Predicting labels from best trained model')
    predict(test_loader)
 def run_training(self):
     """Temporarily wrap the network in DataParallel for training, then restore it."""
     self.maybe_update_lr(self.epoch)
     # amp must be initialized before DP
     # Remember the deep-supervision flag; it is forced on while training.
     ds = self.network.do_ds
     self.network.do_ds = True
     # self.network = DataParallel(self.network, tuple(range(self.num_gpus)), )
     self.network = DataParallel(self.network, device_ids = list(range(0, self.num_gpus)))
     
     ret = nnUNetTrainer.run_training(self)
     # Unwrap so later code (checkpointing/inference) sees the bare module,
     # and restore the original deep-supervision setting.
     self.network = self.network.module
     self.network.do_ds = ds
     return ret
Beispiel #6
0
def pixel_fit_image(im3d,sS=3.,ss=1.5,th_brightness= 5,ksize_max=3,plt_val=False):
    """Detect bright local maxima in a 3D image via a difference-of-Gaussians.

    Blurs `im3d` with a big (sigma sS) and a small (sigma ss) Gaussian, keeps
    local maxima whose log-ratio exceeds th_brightness standard deviations,
    and returns their coordinates/brightness as rows [z, x, y, h].
    All filtering runs on the GPU (filters wrapped in DataParallel).
    """
    input_= torch.tensor([[im3d]]).cuda()
    ### compute the big gaussian filter ##########
    gaussian_kernel_ = gaussian_kernel(sxyz = [sS,sS,sS],cut_off = 2.5)
    ksz = len(gaussian_kernel_)
    gaussian_kernel_ = torch.FloatTensor(gaussian_kernel_).cuda().view(1, 1, ksz,ksz,ksz)
    #gaussian_kernel_ = gaussian_kernel_.repeat(channels, 1, 1, 1)
    # Fixed (non-trainable) Conv3d acting as a Gaussian blur.
    gfilt_big = DataParallel(nn.Conv3d(1,1,ksz, stride=1,padding=0,bias=False)).cuda()
    gfilt_big.module.weight.data = gaussian_kernel_
    gfilt_big.module.weight.requires_grad = False
    # Replication padding keeps the output the same size as the input.
    gfit_big_ = gfilt_big(pd.ReplicationPad3d(int(ksz/2.))(input_))

    ### compute the small gaussian filter ##########
    gaussian_kernel_ = gaussian_kernel(sxyz = [1,ss,ss],cut_off = 2.5)
    ksz = len(gaussian_kernel_)
    gaussian_kernel_ = torch.FloatTensor(gaussian_kernel_).cuda().view(1, 1, ksz,ksz,ksz)
    #gaussian_kernel_ = gaussian_kernel_.repeat(channels, 1, 1, 1)
    gfilt_sm = DataParallel(nn.Conv3d(1,1,ksz, stride=1,padding=0,bias=False)).cuda()
    gfilt_sm.module.weight.data = gaussian_kernel_
    gfilt_sm.module.weight.requires_grad = False
    gfilt_sm_ = gfilt_sm(pd.ReplicationPad3d(int(ksz/2.))(input_))

    ### compute the maximum filter ##########
    # A voxel is a local maximum iff max-pooling leaves its value unchanged.
    max_filt = DataParallel(nn.MaxPool3d(ksize_max, stride=1,padding=int(ksize_max/2), return_indices=False)).cuda()
    local_max = max_filt(gfilt_sm_)==gfilt_sm_

    # Log-ratio of small vs. big blur: a difference-of-Gaussians in log space.
    g_dif = torch.log(gfilt_sm_)-torch.log(gfit_big_)
    std_ = torch.std(g_dif)
    # Keep local maxima brighter than th_brightness standard deviations.
    g_keep = (g_dif>std_*th_brightness)*local_max
    
    #inds = torch.nonzero(g_dif)
    g_keep = g_keep.cpu().numpy()
    inds = np.array(np.nonzero(g_keep)).T
    
    # Default: empty (0, 4) result when no candidate voxels survive.
    zxyhf = np.array([[],[],[],[]]).T
    zf,xf,yf,hf = zxyhf.T 
    if len(inds):
        # inds has 5 columns (batch, channel, z, x, y) from the [[im3d]] wrap.
        brightness = g_dif[inds[:,0],inds[:,1],inds[:,2],inds[:,3],inds[:,4]]
        # bring back to CPU
        zf,xf,yf = inds[:,-3:].T#.cpu().numpy().T
        hf = brightness.cpu().numpy()
        zxyhf = np.array([zf,xf,yf,hf]).T
    torch.cuda.empty_cache()
    if plt_val:
        plt.figure()
        plt.scatter(yf,xf,s=150,facecolor='none',edgecolor='r')
        plt.imshow(np.max(im3d,axis=0),vmax=2)
        plt.show()
    return zxyhf
Beispiel #7
0
    def _make_model(self):
        """Create the training graph and optimizer, resuming if configured."""
        self.logger.info("Creating graph and optimizer...")
        net = get_model('train', self.joint_num)
        net = DataParallel(net).cuda()
        opt = self.get_optimizer(net)
        if cfg.continue_train:
            start_epoch, net, opt = self.load_model(net, opt)
        else:
            start_epoch = 0
        net.train()

        self.start_epoch = start_epoch
        self.model = net
        self.optimizer = opt
    def run_training(self):
        """Temporarily wrap the network in DataParallel, train, then unwrap."""
        self.maybe_update_lr(self.epoch)

        # amp must be initialized before DP

        # Remember the deep-supervision flag; it is forced on for training.
        ds = self.network.do_ds
        self.network.do_ds = True
        self.network = DataParallel(
            self.network,
            tuple(range(self.num_gpus)),
        )
        ret = tuframeworkTrainer.run_training(self)
        # Restore the bare module and its original deep-supervision setting.
        self.network = self.network.module
        self.network.do_ds = ds
        return ret
Beispiel #9
0
def med_correct(im3d, ksize=32, cuda=True):
    """Flat-field correct a 3D image stack by dividing out a blockwise median.

    :param im3d: array of shape (z, x, y).
    :param ksize: window size of the median filter/pooling.
    :param cuda: if True use a GPU MedianPool2d; otherwise a NumPy fallback.
    :return: im3d divided by its upsampled blockwise-median background.
    """
    if cuda:
        sz = im3d.shape[-2:]
        input_ = torch.tensor(np.array([im3d], dtype=np.float32)).cuda()
        med2d_filt = DataParallel(MedianPool2d(ksize, stride=ksize)).cuda()
        output_ = med2d_filt(input_)
        # Upsample the pooled median map back to the original (x, y) size.
        imf = nn.functional.interpolate(output_, size=sz, mode='bilinear').cpu().numpy()[0]
        return im3d / imf
    else:
        sz, sxo, syo = im3d.shape
        # Crop so both spatial dims are exact multiples of ksize.
        num_windows_x = int(sxo / ksize)
        num_windows_y = int(syo / ksize)
        sx = num_windows_x * ksize
        sy = num_windows_y * ksize
        im = im3d[:, :sx, :sy]

        # Reshape into (z, nx, ny, ksize*ksize) windows; median per window.
        im_reshape = im.reshape([sz, num_windows_x, ksize, num_windows_y, ksize])
        im_reshape = np.swapaxes(im_reshape, 2, 3)
        im_reshape = im_reshape.reshape(list(im_reshape.shape[:-2]) + [ksize * ksize])
        im_med = np.median(im_reshape, axis=-1)
        sz, sx_, sy_ = im_med.shape

        # Linearly resize the median map to the original (x, y) size.
        im_medf = zoom(im_med, [1, float(sxo) / sx_, float(syo) / sy_], order=1)
        # Bug fix: the original returned im3d/imf, but `imf` is undefined in
        # this branch (NameError); the computed background is `im_medf`.
        return im3d / im_medf
Beispiel #10
0
    def initialize_network(self):
        """
        replace genericUNet with the implementation of above for super speeds
        """
        # Pick 3D or 2D building blocks depending on the configured dimensionality.
        if self.threeD:
            conv_op, dropout_op, norm_op = nn.Conv3d, nn.Dropout3d, nn.InstanceNorm3d
        else:
            conv_op, dropout_op, norm_op = nn.Conv2d, nn.Dropout2d, nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}

        self.network = Generic_UNet_DP(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, InitWeights_He(1e-2),
            self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes,
            False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        # Softmax is applied only at inference time (training uses raw logits).
        self.network.inference_apply_nonlin = softmax_helper
Beispiel #11
0
 def test_stacked_self_attention_can_run_foward_on_multiple_gpus(self):
     """The stacked self-attention encoder should run forward under DataParallel."""
     encoder_params = dict(input_dim=9,
                           hidden_dim=12,
                           projection_dim=9,
                           feedforward_hidden_dim=5,
                           num_layers=3,
                           num_attention_heads=3)
     encoder = StackedSelfAttentionEncoder(**encoder_params).to(0)
     parallel_encoder = DataParallel(encoder, device_ids=[0, 1])
     batch = torch.randn([3, 5, 9]).to(0)
     output = parallel_encoder(batch, None)
     assert list(output.size()) == [3, 5, 12]
Beispiel #12
0
    def __init__(self, weightsPath, principal_points=None, focal=(1500, 1500)):
        """Load the pose network from a checkpoint and prepare preprocessing.

        :param weightsPath: path to the saved checkpoint (expects key 'network').
        :param principal_points: optional camera principal point.
        :param focal: camera focal lengths (fx, fy).
        """
        self.focal = focal
        self.principal_points = principal_points

        net = get_pose_net(cfg, False)
        net = DataParallel(net).cuda()
        checkpoint = torch.load(weightsPath)
        net.load_state_dict(checkpoint['network'])
        net.eval()
        self.net = net

        # Normalize inputs with the dataset pixel statistics from the config.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cfg.pixel_mean, std=cfg.pixel_std)
        ])
Beispiel #13
0
def get_model(model, gpus=None, num_classes=1000, train_aug="default"):
    """Build a classifier by name, load its checkpoint, and move it to GPU(s).

    :param model: model name, e.g. 'resnet50' or an 'efficientnet-l2t' variant.
    :param gpus: list of device ids, or 'cpu'/None to stay on CPU.
    :param num_classes: number of output classes (efficientnet path only).
    :param train_aug: augmentation key selecting the checkpoint to load.
    :raises ValueError: for unknown model names.
    """
    checkpoint = _checkpoints[model][train_aug]
    print(model, train_aug, checkpoint)
    if 'efficientnet-l2t' in model:
        norm_type = os.environ.get('norm', 'batch')
        model = EfficientNet.from_name(model,
                                       num_classes=num_classes,
                                       multiple_feat=True,
                                       norm_type=norm_type)  # instance norm

    # IMAGENET -----------------------------------------------------
    elif model == 'resnet50':
        model = resnet50(pretrained=True)
    else:
        raise ValueError(model)

    if checkpoint and checkpoint != 'modelzoo':
        state = torch.load(checkpoint)
        # Weights may live under 'model' or 'state_dict'; fall back to
        # 'state_dict' when 'model' exists but is not a dict of weights.
        key = 'model' if 'model' in state else 'state_dict'
        if key in state and not isinstance(state[key], dict):
            key = 'state_dict'
        # EMA checkpoints ("omem" runs) keep averaged weights under 'ema'.
        if 'omem' in train_aug:
            key = 'ema'
        print('model epoch=', state.get('epoch', -1))
        # Strip 'module.' prefixes left over from DataParallel training.
        if key in state:
            model.load_state_dict(
                {k.replace('module.', ''): v
                 for k, v in state[key].items()})
        else:
            model.load_state_dict(
                {k.replace('module.', ''): v
                 for k, v in state.items()})  # without key

    if gpus not in ['cpu', None]:
        if len(gpus) > 1:
            model = DataParallel(model, device_ids=gpus)
        model = model.cuda()

    return model
Beispiel #14
0
    def set_device(self, device):
        """Move the wrapped module (and loss, if set) to `device`.

        Accepts a single device spec or a list/tuple of cuda devices. A
        multi-device spec wraps the module in DataParallel and uses the first
        device as the primary one.

        :raises ValueError: if a multi-device spec contains a non-cuda device,
            an index-less cuda device, or duplicate indices.
        """
        device = cast_device(device)
        str_device = device_to_str(device)
        nn_module = self.get_nn_module()

        if isinstance(device, (list, tuple)):
            device_ids = []
            for dev in device:
                # Improvement: the original raised bare ValueError with no
                # message, which is hard to diagnose from a traceback.
                if dev.type != 'cuda':
                    raise ValueError(
                        "Multi-device specs may contain only cuda devices, "
                        "got '{}'".format(dev))
                if dev.index is None:
                    raise ValueError(
                        "Cuda devices must have an explicit index "
                        "(e.g. 'cuda:0'), got '{}'".format(dev))
                device_ids.append(dev.index)
            if len(device_ids) != len(set(device_ids)):
                raise ValueError("Cuda device indices must be unique")
            nn_module = DataParallel(nn_module, device_ids=device_ids)
            device = device[0]

        self.params['device'] = str_device
        self.device = device
        self.nn_module = nn_module.to(self.device)
        if self.loss is not default:
            self.loss = self.loss.to(self.device)
Beispiel #15
0
  def __init__(self, network,
               w_lr=0.01,
               w_mom=0.9,
               w_wd=1e-4,
               t_lr=0.001,
               t_wd=3e-3,
               t_beta=(0.5, 0.999),
               init_temperature=5.0,
               temperature_decay=0.965,
               logger=logging,
               lr_scheduler=None,
               gpus=None,
               save_theta_prefix='',
               resource_weight=0.001):
    """Build the SNAS search trainer: weight/arch optimizers, LR schedule, meters.

    `lr_scheduler` defaults to {'T_max': 200} and `gpus` to [0]; both were
    previously mutable default arguments (shared across all calls) and are now
    created fresh per call. `gpus` may also be a comma-separated string, e.g.
    "0,1".
    """
    # Bug fix: avoid mutable default arguments (dict/list in the signature).
    if lr_scheduler is None:
      lr_scheduler = {'T_max' : 200}
    if gpus is None:
      gpus = [0]
    assert isinstance(network, SNAS)
    network.apply(weights_init)
    network = network.train().cuda()
    self._criterion = nn.CrossEntropyLoss().cuda()

    # Architecture (alpha) and weight parameters get separate optimizers.
    alpha_params = network.arch_parameters()
    mod_params = network.model_parameters()
    self.alpha = alpha_params
    if isinstance(gpus, str):
      gpus = [int(i) for i in gpus.strip().split(',')]
    network = DataParallel(network, gpus)
    self._mod = network
    self.gpus = gpus

    self.w = mod_params
    self._tem_decay = temperature_decay
    self.temp = init_temperature
    self.logger = logger
    self.save_theta_prefix = save_theta_prefix
    self._resource_weight = resource_weight

    self._loss_avg = AvgrageMeter('loss')
    self._acc_avg = AvgrageMeter('acc')
    self._res_cons_avg = AvgrageMeter('resource-constraint')

    # SGD on the network weights, Adam on the architecture parameters.
    self.w_opt = torch.optim.SGD(
                    mod_params,
                    w_lr,
                    momentum=w_mom,
                    weight_decay=w_wd)
    self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)
    self.t_opt = torch.optim.Adam(
                    alpha_params,
                    lr=t_lr, betas=t_beta,
                    weight_decay=t_wd)
Beispiel #16
0
    def _make_model(self):
        """Load the snapshot for `self.test_epoch` and build the eval-mode model."""
        model_path = os.path.join(cfg.model_dir, 'snapshot_%d.pth.tar' % self.test_epoch)
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        self.logger.info('Load checkpoint from {}'.format(model_path))

        # prepare network
        self.logger.info("Creating graph...")
        net = get_model(self.vertex_num, self.joint_num, 'test')
        net = DataParallel(net).cuda()
        checkpoint = torch.load(model_path)
        # strict=False tolerates missing/unexpected keys in the snapshot.
        net.load_state_dict(checkpoint['network'], strict=False)
        net.eval()

        self.model = net
Beispiel #17
0
    def _make_model(self, test_epoch):
        """Load the pose-network snapshot for `test_epoch` and set eval mode."""
        self.test_epoch = test_epoch
        snapshot_name = 'snapshot_%d.pth.tar' % self.test_epoch
        model_path = os.path.join(cfg.model_dir, snapshot_name)
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path

        # Build the network and restore the checkpointed weights.
        net = get_pose_net(self.backbone, False, self.joint_num)
        net = DataParallel(net).cuda()
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['network'])
        net.eval()

        self.model = net
Beispiel #18
0
class nnUNetTrainerV2CascadeFullRes_DP(nnUNetTrainerV2CascadeFullRes):
    """DataParallel (DP) variant of the cascade full-resolution nnU-Net trainer.

    The network is wrapped in torch.nn.DataParallel only for the duration of
    run_training(); the loss is assembled from per-sample CE values and
    tp/fp/fn statistics returned by the DP network, so dice can be reduced
    correctly across GPUs.
    """

    def __init__(self,
                 plans_file,
                 fold,
                 output_folder=None,
                 dataset_directory=None,
                 batch_dice=True,
                 stage=None,
                 unpack_data=True,
                 deterministic=True,
                 num_gpus=1,
                 distribute_batch_size=False,
                 fp16=False,
                 previous_trainer="nnUNetTrainerV2_DP"):
        """Extend the cascade trainer with multi-GPU (DP) settings.

        :param num_gpus: number of GPUs to spread the batch over.
        :param distribute_batch_size: if True keep the planned batch size and
            split it across GPUs; otherwise multiply it by num_gpus.
        :param previous_trainer: trainer whose lowres predictions feed this stage.
        """
        super().__init__(plans_file, fold, output_folder, dataset_directory,
                         batch_dice, stage, unpack_data, deterministic,
                         previous_trainer, fp16)
        # Stored so the trainer can be re-instantiated from a checkpoint.
        self.init_args = (plans_file, fold, output_folder, dataset_directory,
                          batch_dice, stage, unpack_data, deterministic,
                          num_gpus, distribute_batch_size, fp16,
                          previous_trainer)

        self.num_gpus = num_gpus
        self.distribute_batch_size = distribute_batch_size
        # Dice settings: exclude background; smooth term avoids division by zero.
        self.dice_do_BG = False
        self.dice_smooth = 1e-5
        # Derive the folder holding the previous (3d_lowres) stage predictions
        # from the output-folder layout: .../<task>/<trainer>__<plans>/<fold>.
        if self.output_folder is not None:
            task = self.output_folder.split("/")[-3]
            plans_identifier = self.output_folder.split("/")[-2].split(
                "__")[-1]
            folder_with_segs_prev_stage = join(
                network_training_output_dir, "3d_lowres", task,
                previous_trainer + "__" + plans_identifier, "pred_next_stage")
            self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
        else:
            self.folder_with_segs_from_prev_stage = None
        print(self.folder_with_segs_from_prev_stage)

    def get_basic_generators(self):
        """Create the 3D train/val data loaders (the cascade has no 2D mode)."""
        self.load_dataset()
        self.do_split()
        if self.threeD:
            dl_tr = DataLoader3D(self.dataset_tr,
                                 self.basic_generator_patch_size,
                                 self.patch_size,
                                 self.batch_size,
                                 True,
                                 oversample_foreground_percent=self.
                                 oversample_foreground_percent,
                                 pad_mode="constant",
                                 pad_sides=self.pad_all_sides)
            dl_val = DataLoader3D(self.dataset_val,
                                  self.patch_size,
                                  self.patch_size,
                                  self.batch_size,
                                  True,
                                  oversample_foreground_percent=self.
                                  oversample_foreground_percent,
                                  pad_mode="constant",
                                  pad_sides=self.pad_all_sides)
        else:
            raise NotImplementedError("2D has no cascade")

        return dl_tr, dl_val

    def process_plans(self, plans):
        """Process the experiment plans, scaling batch size for multi-GPU use."""
        super().process_plans(plans)
        if not self.distribute_batch_size:
            # Each GPU gets the full planned batch size.
            self.batch_size = self.num_gpus * self.plans['plans_per_stage'][
                self.stage]['batch_size']
        else:
            # Planned batch size is split across GPUs; warn on uneven splits.
            if self.batch_size < self.num_gpus:
                print(
                    "WARNING: self.batch_size < self.num_gpus. Will not be able to use the GPUs well"
                )
            elif self.batch_size % self.num_gpus != 0:
                print(
                    "WARNING: self.batch_size % self.num_gpus != 0. Will not be able to use the GPUs well"
                )

    def initialize(self, training=True, force_load_plans=False):
        """Load plans, set up loss weights, data generators and the network."""
        if not self.was_initialized:
            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)
            self.setup_DA_params()

            ################# Here we wrap the loss for deep supervision ############
            # Deep-supervision weights halve per resolution level; the lowest
            # level is masked out, then weights are renormalized to sum to 1.
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            weights = np.array([1 / (2**i) for i in range(net_numpool)])
            mask = np.array([
                True if i < net_numpool - 1 else False
                for i in range(net_numpool)
            ])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.loss_weights = weights
            ################# END ###################

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)

            if training:
                # The cascade needs the lowres stage's predicted segmentations.
                if not isdir(self.folder_with_segs_from_prev_stage):
                    raise RuntimeError(
                        "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
                        "segmentations for the next stage")

                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network,
                              (SegmentationNetwork, DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )

        self.was_initialized = True

    def initialize_network(self):
        """
        replace genericUNet with the implementation of above for super speeds
        """
        # Pick 3D or 2D building blocks depending on the configured dimensionality.
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op,
            dropout_op_kwargs, net_nonlin, net_nonlin_kwargs, True, False,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        # Softmax is applied only at inference time (training uses raw logits).
        self.network.inference_apply_nonlin = softmax_helper

    def run_training(self):
        """Temporarily wrap the network in DataParallel for training, then unwrap."""
        self.maybe_update_lr(self.epoch)
        # amp must be initialized before DP
        # Remember the deep-supervision flag; it is forced on while training.
        ds = self.network.do_ds
        self.network.do_ds = True
        # self.network = DataParallel(self.network, tuple(range(self.num_gpus)), )
        self.network = DataParallel(self.network,
                                    device_ids=list(range(0, self.num_gpus)))

        ret = nnUNetTrainer.run_training(self)
        # Restore the bare module and its original deep-supervision setting.
        self.network = self.network.module
        self.network.do_ds = ds
        return ret

    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        """Run one train/eval step; the DP network returns per-GPU loss parts.

        Returns the scalar loss as a numpy value.
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            # Mixed precision: forward under autocast, backward via GradScaler.
            with autocast():
                ret = self.network(data,
                                   target,
                                   return_hard_tp_fp_fn=run_online_evaluation)
                if run_online_evaluation:
                    ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                    self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
                else:
                    ces, tps, fps, fns = ret
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ret = self.network(data,
                               target,
                               return_hard_tp_fp_fn=run_online_evaluation)
            if run_online_evaluation:
                ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
            else:
                ces, tps, fps, fns = ret
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                l.backward()
                clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
        """Accumulate per-class dice statistics, averaged over the batch."""
        tp_hard = tp_hard.detach().cpu().numpy().mean(0)
        fp_hard = fp_hard.detach().cpu().numpy().mean(0)
        fn_hard = fn_hard.detach().cpu().numpy().mean(0)
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))

    def compute_loss(self, ces, tps, fps, fns):
        """Deep-supervision loss: weighted sum of (CE + soft dice) per level."""
        loss = None
        for i in range(len(ces)):
            # Optionally drop the background channel from the dice statistics.
            if not self.dice_do_BG:
                tp = tps[i][:, 1:]
                fp = fps[i][:, 1:]
                fn = fns[i][:, 1:]
            else:
                tp = tps[i]
                fp = fps[i]
                fn = fns[i]

            # batch_dice: pool statistics over the batch before the ratio.
            if self.batch_dice:
                tp = tp.sum(0)
                fp = fp.sum(0)
                fn = fn.sum(0)
            else:
                pass

            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth

            # Negative soft dice (to be minimized), averaged over classes.
            dice_loss = (-nominator / denominator).mean()
            if loss is None:
                loss = self.loss_weights[i] * (ces[i].mean() + dice_loss)
            else:
                loss += self.loss_weights[i] * (ces[i].mean() + dice_loss)
        return loss
Beispiel #19
0

def cal_acc(pred, Y):
    """Return classification accuracy as a 0-dim float tensor.

    :param pred: logits or probabilities of shape [N, C].
    :param Y: integer class labels of shape [N].
    """
    # .float().mean() is equivalent to sum()/N but avoids long-tensor integer
    # division (which truncates to 0 on older PyTorch versions).
    return (pred.argmax(dim=1) == Y).float().mean()


# Paths and hyper-parameters for the phase-A patch classifier.
model_root = './models'
root = './data/patches/phaseA128'
num_classes = 10
MAX_EPOCHES = 300

if __name__ == "__main__":
    # NOTE(review): assumes `device`/`cuda` were imported from torch and that
    # GPUs 0 and 1 exist for DataParallel — confirm in the full script.
    device = device('cuda' if cuda.is_available() else 'cpu')
    model = PhaseANet(num_classes=num_classes)
    model = model.to(device)
    model = DataParallel(module=model, device_ids=[0, 1])

    # Xavier-initialize conv/linear weights and zero their biases.
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            for name, param in m.named_parameters():
                if 'weight' in name:
                    init.xavier_normal_(param, )
                    # print(name, param.data)
                elif 'bias' in name:
                    init.constant_(param, 0)
                    # print(name, param.data)

    # Resume from the latest checkpoint in model_root, if one exists.
    initial_epoch, state_dict = util.findLastCheckpoint(model_root)
    if state_dict is not None:
        model.load_state_dict(state_dict)
Beispiel #20
0
# Joint metadata for a 21-joint human pose model: joint names, the left/right
# pairs swapped when flipping images, and the bone connections of the skeleton.
joints_name = ('Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist',
               'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee',
               'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine',
               'Head', 'R_Hand', 'L_Hand', 'R_Toe', 'L_Toe')
flip_pairs = ((2, 5), (3, 6), (4, 7), (8, 11), (9, 12), (10, 13), (17, 18),
              (19, 20))
skeleton = ((0, 16), (16, 1), (1, 15), (15, 14), (14, 8), (14, 11), (8, 9),
            (9, 10), (10, 19), (11, 12), (12, 13), (13, 20), (1, 2), (2, 3),
            (3, 4), (4, 17), (1, 5), (5, 6), (6, 7), (7, 18))

# snapshot load
# Load the requested checkpoint into the pose network and set eval mode.
model_path = './snapshot_%d.pth.tar' % int(args.test_epoch)
assert osp.exists(model_path), 'Cannot find model at ' + model_path
print('Load checkpoint from {}'.format(model_path))
model = get_pose_net(cfg, False, joint_num)
model = DataParallel(model).cuda()
ckpt = torch.load(model_path)
model.load_state_dict(ckpt['network'])
model.eval()

# prepare input image
# Normalization uses the dataset pixel statistics from the config.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cfg.pixel_mean, std=cfg.pixel_std)
])
img_path = 'input.jpg'
original_img = cv2.imread(img_path)
original_img_height, original_img_width = original_img.shape[:2]
bbox_list = [
Beispiel #21
0
class SSRunner(object):
    """Runner for a semantic-segmentation network: training (``train_ss``),
    evaluation (``eval_ss``), multi-scale inference (``inference_ss``) and
    checkpoint handling.  All settings come from the ``config`` object.
    """

    def __init__(self, config):
        self.config = config

        # Data: build train/val datasets and loaders from the config.
        self.dataset_ss_train, _, self.dataset_ss_val = DatasetUtil.get_dataset_by_type(
            DatasetUtil.dataset_type_ss,
            self.config.ss_size,
            is_balance=self.config.is_balance_data,
            data_root=self.config.data_root_path,
            train_label_path=self.config.label_path,
            max_size=self.config.max_size)
        self.data_loader_ss_train = DataLoader(self.dataset_ss_train,
                                               self.config.ss_batch_size,
                                               True,
                                               num_workers=16,
                                               drop_last=True)
        self.data_loader_ss_val = DataLoader(self.dataset_ss_val,
                                             self.config.ss_batch_size,
                                             False,
                                             num_workers=16,
                                             drop_last=True)

        # Model: the concrete net class is injected through the config.
        self.net = self.config.Net(num_classes=self.config.ss_num_classes,
                                   output_stride=self.config.output_stride,
                                   arch=self.config.arch)

        # BalancedDataParallel lets GPU 0 take a smaller (here: zero) share
        # of the batch; plain DataParallel splits evenly otherwise.
        if self.config.only_train_ss:
            self.net = BalancedDataParallel(0, self.net, dim=0).cuda()
        else:
            self.net = DataParallel(self.net).cuda()
            pass
        cudnn.benchmark = True

        # Optimize: classifier head trains at 10x the backbone LR.
        self.optimizer = optim.SGD(params=[
            {
                'params': self.net.module.model.backbone.parameters(),
                'lr': self.config.ss_lr
            },
            {
                'params': self.net.module.model.classifier.parameters(),
                'lr': self.config.ss_lr * 10
            },
        ],
                                   lr=self.config.ss_lr,
                                   momentum=0.9,
                                   weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.config.ss_milestones, gamma=0.1)

        # Loss: 255 is the conventional "ignore" label for segmentation.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=255,
                                           reduction='mean').cuda()
        pass

    def train_ss(self, start_epoch=0, model_file_name=None):
        """Run the training loop from ``start_epoch``; optionally resume
        from ``model_file_name``.  Periodically evaluates and keeps the
        best-mIoU checkpoint, plus regular and final snapshots.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        # self.eval_ss(epoch=0)
        best_iou = 0.0

        for epoch in range(start_epoch, self.config.ss_epoch_num):
            Tools.print()
            Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr'],
                self.optimizer.param_groups[1]['lr']),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 1 Train the model
            all_loss = 0.0
            self.net.train()
            if self.config.is_balance_data:
                # Re-sample the balanced subset each epoch.
                self.dataset_ss_train.reset()
                pass
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_train),
                                    total=len(self.data_loader_ss_train)):
                inputs, labels = inputs.float().cuda(), labels.long().cuda()
                self.optimizer.zero_grad()

                result = self.net(inputs)
                loss = self.ce_loss(result, labels)

                loss.backward()
                self.optimizer.step()

                all_loss += loss.item()

                # Evaluate ~10 times per epoch; save whenever mIoU improves.
                # NOTE(review): `len(loader) // 10` is 0 when the loader has
                # fewer than 10 batches, which raises ZeroDivisionError here.
                if (i + 1) % (len(self.data_loader_ss_train) // 10) == 0:
                    score = self.eval_ss(epoch=epoch)
                    mean_iou = score["Mean IoU"]
                    if mean_iou > best_iou:
                        best_iou = mean_iou
                        save_file_name = Tools.new_dir(
                            os.path.join(
                                self.config.ss_model_dir,
                                "ss_{}_{}_{}.pth".format(epoch, i, best_iou)))
                        torch.save(self.net.state_dict(), save_file_name)
                        Tools.print("Save Model to {}".format(save_file_name),
                                    txt_path=self.config.ss_save_result_txt)
                        Tools.print()
                    pass
                pass
            self.scheduler.step()
            ###########################################################################

            Tools.print("[E:{:3d}/{:3d}] ss loss:{:.4f}".format(
                epoch, self.config.ss_epoch_num,
                all_loss / len(self.data_loader_ss_train)),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 2 Save the model (periodic snapshot)
            if epoch % self.config.ss_save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.ss_model_dir,
                                 "ss_{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.ss_save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3 Evaluate the model (periodic)
            if epoch % self.config.ss_eval_epoch_freq == 0:
                score = self.eval_ss(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.ss_model_dir,
                         "ss_final_{}.pth".format(self.config.ss_epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.ss_save_result_txt)
        Tools.print()

        self.eval_ss(epoch=self.config.ss_epoch_num)
        pass

    def eval_ss(self, epoch=0, model_file_name=None):
        """Evaluate on the validation loader; returns the metrics dict
        produced by ``StreamSegMetrics.get_results()``.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_val),
                                    total=len(self.data_loader_ss_val)):
                inputs = inputs.float().cuda()
                labels = labels.long().cuda()
                outputs = self.net(inputs)
                preds = outputs.detach().max(dim=1)[1].cpu().numpy()
                targets = labels.cpu().numpy()

                metrics.update(targets, preds)
                pass
            pass

        score = metrics.get_results()
        Tools.print("{} {}".format(epoch, metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score

    def inference_ss(self,
                     model_file_name=None,
                     data_loader=None,
                     save_path=None):
        """Multi-scale inference: averages predictions over the per-sample
        input list, resizes to the original image size, updates metrics and
        (optionally) writes colourised label/prediction images.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        final_save_path = Tools.new_dir("{}_final".format(save_path))

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs, labels,
                    image_info_list) in tqdm(enumerate(data_loader),
                                             total=len(data_loader)):
                assert len(image_info_list) == 1

                # Ground-truth label, resized back to the original image size.
                # Images beyond max_size are resized on CPU to avoid GPU OOM.
                max_size = 1000
                size = Image.open(image_info_list[0]).size
                basename = os.path.basename(image_info_list[0])
                final_name = os.path.join(final_save_path,
                                          basename.replace(".JPEG", ".png"))
                if os.path.exists(final_name):
                    continue

                if size[0] < max_size and size[1] < max_size:
                    targets = F.interpolate(torch.unsqueeze(
                        labels[0].float().cuda(), dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest").detach().cpu()
                else:
                    targets = F.interpolate(torch.unsqueeze(labels[0].float(),
                                                            dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest")
                targets = targets[0].long().numpy()

                # Prediction: average logits over all provided scales/crops.
                outputs = 0
                for input_index, input_one in enumerate(inputs):
                    output_one = self.net(input_one.float().cuda())
                    if size[0] < max_size and size[1] < max_size:
                        outputs += F.interpolate(
                            output_one,
                            size=(size[1], size[0]),
                            mode="bilinear",
                            align_corners=False).detach().cpu()
                    else:
                        outputs += F.interpolate(output_one.detach().cpu(),
                                                 size=(size[1], size[0]),
                                                 mode="bilinear",
                                                 align_corners=False)
                        pass
                    pass
                outputs = outputs / len(inputs)
                preds = outputs.max(dim=1)[1].numpy()

                # Metrics
                metrics.update(targets, preds)

                if save_path:
                    Image.open(image_info_list[0]).save(
                        os.path.join(save_path, basename))
                    DataUtil.gray_to_color(
                        np.asarray(targets[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", "_l.png")))
                    DataUtil.gray_to_color(np.asarray(
                        preds[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", ".png")))
                    Image.fromarray(np.asarray(
                        preds[0], dtype=np.uint8)).save(final_name)
                    pass
                pass
            pass

        score = metrics.get_results()
        Tools.print("{}".format(metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score

    def load_model(self, model_file_name):
        """Load a state dict saved by this runner into ``self.net``."""
        Tools.print("Load model form {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        checkpoint = torch.load(model_file_name)

        # NOTE(review): the single-GPU branch is a no-op — the 'module.'
        # prefix strip is commented out; confirm whether it is still needed.
        if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1:
            # checkpoint = {key.replace("module.", ""): checkpoint[key] for key in checkpoint}
            pass

        self.net.load_state_dict(checkpoint, strict=True)
        Tools.print("Restore from {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        pass

    def stat(self):
        """Print a parameter/FLOPs summary of the network."""
        stat(self.net, (3, self.config.ss_size, self.config.ss_size))
        pass

    pass
Beispiel #22
0
def train(args, pt_dir, chkpt_path, trainloader, devloader, writer, logger, hp,
          hp_str):
    """Train an SLOCountNet speaker-count model with per-batch VAD/VOD/count
    metrics, epoch-level logging and checkpointing.

    Args:
        args: unused here; kept for caller-interface compatibility.
        pt_dir: directory checkpoints are written into.
        chkpt_path: checkpoint to resume from, or None for a fresh run.
        trainloader: training DataLoader yielding (features, labels).
        devloader: validation DataLoader passed to ``validate``.
        writer: metrics sink exposing ``log_metrics``.
        logger: logging-style logger.
        hp: hyper-parameter namespace (``hp.train.*``, ``hp.features.*``).
        hp_str: string dump of ``hp``, stored inside checkpoints.

    Returns:
        Best validation loss observed, or None if an exception aborted
        training (it is logged with a traceback and swallowed).
    """

    model = get_SLOCountNet(hp).cuda()

    # BUGFIX: the original did print("FOV: {}", value) — the value was a
    # second print() argument, so the placeholder was never substituted.
    print("FOV: {}".format(model.get_fov(hp.features.n_fft)))
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("N_parameters : {}".format(params))
    model = DataParallel(model)

    if hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    epoch = 0
    best_loss = np.inf

    def _save_checkpoint(step):
        # Persist model/optimizer state plus the hparams string so resume
        # can detect configuration drift.
        save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % step)
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
                'hp_str': hp_str,
            }, save_path)
        logger.info("Saved checkpoint to: %s" % save_path)

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['step']

        # will use new given hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams is different from checkpoint.")
    else:
        logger.info("Starting new training run")

    try:

        for epoch in range(epoch, hp.train.n_epochs):

            vad_scores = Binarymetrics.BinaryMeter()  # activity scores
            vod_scores = Binarymetrics.BinaryMeter()  # overlap scores
            count_scores = Binarymetrics.MultiMeter()  # Countnet scores

            model.train()
            tot_loss = 0

            with tqdm(trainloader) as t:
                t.set_description("Epoch: {}".format(epoch))

                for count, batch in enumerate(trainloader):

                    features, labels = batch
                    features = features.cuda()
                    labels = labels.cuda()

                    preds = model(features)

                    loss = criterion(preds, labels)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # compute proper metrics for VAD
                    loss = loss.item()

                    if loss > 1e8 or math.isnan(loss):  # check if exploded
                        logger.error("Loss exploded to %.02f at step %d!" %
                                     (loss, epoch))
                        raise Exception("Loss exploded")

                    # VAD = "any speaker active": classes 1..4 collapsed.
                    VADpreds = torch.sum(torch.exp(preds[:, 1:5, :]),
                                         dim=1).unsqueeze(1)
                    VADlabels = torch.sum(labels[:, 1:5, :],
                                          dim=1).unsqueeze(1)
                    vad_scores.update(VADpreds, VADlabels)

                    # VOD = "overlap": two or more speakers, classes 2..4.
                    VODpreds = torch.sum(torch.exp(preds[:, 2:5, :]),
                                         dim=1).unsqueeze(1)
                    VODlabels = torch.sum(labels[:, 2:5, :],
                                          dim=1).unsqueeze(1)
                    vod_scores.update(VODpreds, VODlabels)

                    # Exact speaker-count accuracy (argmax over classes).
                    count_scores.update(
                        torch.argmax(torch.exp(preds), 1).unsqueeze(1),
                        torch.argmax(labels, 1).unsqueeze(1))

                    tot_loss += loss

                    vad_fa = vad_scores.get_fa().item()
                    vad_miss = vad_scores.get_miss().item()
                    vad_precision = vad_scores.get_precision().item()
                    vad_recall = vad_scores.get_recall().item()
                    vad_matt = vad_scores.get_matt().item()
                    vad_f1 = vad_scores.get_f1().item()
                    vad_tp = vad_scores.tp.item()
                    vad_tn = vad_scores.tn.item()
                    vad_fp = vad_scores.fp.item()
                    vad_fn = vad_scores.fn.item()

                    vod_fa = vod_scores.get_fa().item()
                    vod_miss = vod_scores.get_miss().item()
                    vod_precision = vod_scores.get_precision().item()
                    vod_recall = vod_scores.get_recall().item()
                    vod_matt = vod_scores.get_matt().item()
                    vod_f1 = vod_scores.get_f1().item()
                    vod_tp = vod_scores.tp.item()
                    vod_tn = vod_scores.tn.item()
                    vod_fp = vod_scores.fp.item()
                    vod_fn = vod_scores.fn.item()

                    count_fa = count_scores.get_accuracy().item()
                    count_miss = count_scores.get_miss().item()
                    count_precision = count_scores.get_precision().item()
                    count_recall = count_scores.get_recall().item()
                    count_matt = count_scores.get_matt().item()
                    count_f1 = count_scores.get_f1().item()
                    count_tp = count_scores.get_tp().item()
                    count_tn = count_scores.get_tn().item()
                    count_fp = count_scores.get_fp().item()
                    count_fn = count_scores.get_fn().item()

                    t.set_postfix(loss=tot_loss / (count + 1),
                                  vad_miss=vad_miss,
                                  vad_fa=vad_fa,
                                  vad_prec=vad_precision,
                                  vad_recall=vad_recall,
                                  vad_matt=vad_matt,
                                  vad_f1=vad_f1,
                                  vod_miss=vod_miss,
                                  vod_fa=vod_fa,
                                  vod_prec=vod_precision,
                                  vod_recall=vod_recall,
                                  vod_matt=vod_matt,
                                  vod_f1=vod_f1,
                                  count_miss=count_miss,
                                  count_fa=count_fa,
                                  count_prec=count_precision,
                                  count_recall=count_recall,
                                  count_matt=count_matt,
                                  count_f1=count_f1)
                    t.update()

            writer.log_metrics("train_vad", loss, vad_fa, vad_miss, vad_recall,
                               vad_precision, vad_f1, vad_matt, vad_tp, vad_tn,
                               vad_fp, vad_fn, epoch)
            writer.log_metrics("train_vod", loss, vod_fa, vod_miss, vod_recall,
                               vod_precision, vod_f1, vod_matt, vod_tp, vod_tn,
                               vod_fp, vod_fn, epoch)
            writer.log_metrics("train_count", loss, count_fa, count_miss,
                               count_recall, count_precision, count_f1,
                               count_matt, count_tp, count_tn, count_fp,
                               count_fn, epoch)
            # end epoch save model and validate it

            val_loss = validate(hp, model, devloader, writer, epoch)

            if hp.train.save_best == 0:
                # Save every epoch unconditionally.
                _save_checkpoint(epoch)
            elif val_loss < best_loss:  # save only when best
                best_loss = val_loss
                # BUGFIX: the original logged "Saved checkpoint" (and read
                # save_path) even on epochs where nothing was saved, which
                # raised NameError before the first improvement.
                _save_checkpoint(epoch)

        return best_loss

    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
Beispiel #23
0
class nnUNetTrainerV2_DP(nnUNetTrainerV2):
    """nnU-Net V2 trainer variant that runs the network under
    ``torch.nn.DataParallel`` across ``num_gpus`` GPUs and re-implements the
    deep-supervision Dice+CE loss from per-GPU ce/tp/fp/fn partial results.
    """

    def __init__(self,
                 plans_file,
                 fold,
                 output_folder=None,
                 dataset_directory=None,
                 batch_dice=True,
                 stage=None,
                 unpack_data=True,
                 deterministic=True,
                 num_gpus=1,
                 distribute_batch_size=False,
                 fp16=False):
        super(nnUNetTrainerV2_DP,
              self).__init__(plans_file, fold, output_folder,
                             dataset_directory, batch_dice, stage, unpack_data,
                             deterministic, fp16)
        # init_args is stored so the trainer can be re-instantiated from a
        # checkpoint with identical construction arguments.
        self.init_args = (plans_file, fold, output_folder, dataset_directory,
                          batch_dice, stage, unpack_data, deterministic,
                          num_gpus, distribute_batch_size, fp16)
        self.num_gpus = num_gpus
        # If False, each GPU gets a full plan batch; if True, the plan batch
        # is split across GPUs (see process_plans).
        self.distribute_batch_size = distribute_batch_size
        self.dice_smooth = 1e-5
        self.dice_do_BG = False
        self.loss = None
        self.loss_weights = None

    def setup_DA_params(self):
        """Scale augmentation worker threads with the number of GPUs."""
        super(nnUNetTrainerV2_DP, self).setup_DA_params()
        self.data_aug_params['num_threads'] = 8 * self.num_gpus

    def process_plans(self, plans):
        """Adjust batch size for multi-GPU: either multiply the planned
        batch size by num_gpus, or keep it and warn if it splits poorly.
        """
        super(nnUNetTrainerV2_DP, self).process_plans(plans)
        if not self.distribute_batch_size:
            self.batch_size = self.num_gpus * self.plans['plans_per_stage'][
                self.stage]['batch_size']
        else:
            if self.batch_size < self.num_gpus:
                print(
                    "WARNING: self.batch_size < self.num_gpus. Will not be able to use the GPUs well"
                )
            elif self.batch_size % self.num_gpus != 0:
                print(
                    "WARNING: self.batch_size % self.num_gpus != 0. Will not be able to use the GPUs well"
                )

    def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - only run this code once
        - loss function wrapper for deep supervision

        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            os.makedirs(self.output_folder, exist_ok=True)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            ################# Here configure the loss for deep supervision ############
            # Resolution level i gets weight 1/2^i; the lowest level is
            # masked out, then weights are renormalised to sum to 1.
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            weights = np.array([1 / (2**i) for i in range(net_numpool)])
            mask = np.array([
                True if i < net_numpool - 1 else False
                for i in range(net_numpool)
            ])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.loss_weights = weights
            ################# END ###################

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            assert isinstance(self.network,
                              (SegmentationNetwork, DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True

    def initialize_network(self):
        """
        replace genericUNet with the implementation of above for super speeds
        """
        # 2D vs 3D op selection follows the plan's dimensionality.
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op,
            dropout_op_kwargs, net_nonlin, net_nonlin_kwargs, True, False,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def initialize_optimizer_and_scheduler(self):
        """Plain SGD with Nesterov momentum; LR is handled manually via
        maybe_update_lr, so no torch scheduler is attached.
        """
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(),
                                         self.initial_lr,
                                         weight_decay=self.weight_decay,
                                         momentum=0.99,
                                         nesterov=True)
        self.lr_scheduler = None

    def run_training(self):
        """Wrap the network in DataParallel for the duration of training,
        then unwrap it and restore the deep-supervision flag.
        """
        self.maybe_update_lr(self.epoch)

        # amp must be initialized before DP

        ds = self.network.do_ds
        self.network.do_ds = True
        self.network = DataParallel(
            self.network,
            tuple(range(self.num_gpus)),
        )
        ret = nnUNetTrainer.run_training(self)
        self.network = self.network.module
        self.network.do_ds = ds
        return ret

    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        """One forward/backward step.  The DP network returns per-GPU
        ce/tp/fp/fn partials (plus hard counts when evaluating); the loss is
        assembled in compute_loss.  Returns the scalar loss as a numpy value.
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                ret = self.network(data,
                                   target,
                                   return_hard_tp_fp_fn=run_online_evaluation)
                if run_online_evaluation:
                    ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                    self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
                else:
                    ces, tps, fps, fns = ret
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                # AMP path: scale, unscale for clipping, then step/update.
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ret = self.network(data,
                               target,
                               return_hard_tp_fp_fn=run_online_evaluation)
            if run_online_evaluation:
                ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
            else:
                ces, tps, fps, fns = ret
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        return l.detach().cpu().numpy()

    def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
        """Accumulate per-class hard Dice (from GPU-averaged tp/fp/fn) into
        the online-evaluation buffers used for epoch summaries.
        """
        tp_hard = tp_hard.detach().cpu().numpy().mean(0)
        fp_hard = fp_hard.detach().cpu().numpy().mean(0)
        fn_hard = fn_hard.detach().cpu().numpy().mean(0)
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))

    def compute_loss(self, ces, tps, fps, fns):
        """Deep-supervision loss: weighted sum over resolution levels of
        (mean CE + soft Dice loss), with Dice built from tp/fp/fn partials.
        Background channel 0 is excluded unless ``dice_do_BG`` is set.
        """
        # we now need to effectively reimplement the loss
        loss = None
        for i in range(len(ces)):
            if not self.dice_do_BG:
                tp = tps[i][:, 1:]
                fp = fps[i][:, 1:]
                fn = fns[i][:, 1:]
            else:
                tp = tps[i]
                fp = fps[i]
                fn = fns[i]

            # batch_dice: pool counts over the batch dim before the ratio.
            if self.batch_dice:
                tp = tp.sum(0)
                fp = fp.sum(0)
                fn = fn.sum(0)
            else:
                pass

            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth

            dice_loss = (-nominator / denominator).mean()
            if loss is None:
                loss = self.loss_weights[i] * (ces[i].mean() + dice_loss)
            else:
                loss += self.loss_weights[i] * (ces[i].mean() + dice_loss)
        ###########
        return loss
Beispiel #24
0
 # Training hyper-parameters for the THN2 tracker model.
 mile_stone = 1
 track_step = 2
 batch_size = 8
 init_lr = 0.01
 devices = ['cuda:0']  # , 'cuda:1']  # , 'cuda:2', 'cuda:3']
 model = THN2(state_len=16)
 print(model)
 model.init_weights()
 start_epoch = 0
 opt_state = None
 # start_epoch, model, opt_state = load_model(
 #     '/home/toka/code/FairMOT/chkpts/thn2/epoch_7_26.pth', model)
 device = 'cuda:0'
 model.to_device(device)
 model = model.to(device)
 model = DataParallel(model, devices)
 # Dataset of dense real tracks; clip length couples epoch_outer with the
 # per-step tracking stride.
 ps = SRealTracks(
     [os.path.join(p4_root, '%s_dense.p4data' % seq) for seq in seqs],
     '/home/toka/data/P4data/seqinfos.yaml',
     [os.path.join(gt_root, seq, 'gt', 'gt.txt') for seq in seqs],
     clip_len=epoch_outer * track_step)
 dl = DataLoader(ps,
                 batch_size=batch_size,
                 shuffle=True,
                 num_workers=8,
                 drop_last=True)
 base_parameters = (y for x, y in model.named_parameters())
 opt = optim.SGD(base_parameters, lr=init_lr, weight_decay=1e-4)
 if opt_state is not None:
     opt.load_state_dict(opt_state)
 # Step the LR down 10x at epochs 50 and 600.
 lr_scheduler = lr_scd.MultiStepLR(opt, [50, 600], gamma=0.1)
Beispiel #25
0
# SMPl mesh: neutral-gender SMPL body model with 6890 vertices.
vertex_num = 6890
smpl_layer = SMPL_Layer(gender='neutral',
                        model_root=cfg.smpl_path +
                        '/smplpytorch/native/models')
face = smpl_layer.th_faces.numpy()
joint_regressor = smpl_layer.th_J_regressor.numpy()
root_joint_idx = 0

# snapshot load: restore the mesh-regression network from the snapshot.
model_path = './snapshot_%d.pth.tar' % int(args.test_epoch)
assert osp.exists(model_path), 'Cannot find model at ' + model_path
print('Load checkpoint from {}'.format(model_path))
model = get_model(vertex_num, joint_num, 'test')

# Checkpoint keys carry the 'module.' prefix, so wrap in DataParallel;
# strict=False tolerates keys missing at test time.
model = DataParallel(model).cpu()
ckpt = torch.load(model_path)
model.load_state_dict(ckpt['network'], strict=False)
model.eval()

# prepare input image
transform = transforms.ToTensor()
img_path = 'input.jpg'
original_img = cv2.imread(img_path)
original_img_height, original_img_width = original_img.shape[:2]

# prepare bbox, then crop/warp the patch (call continues beyond excerpt)
bbox = [139.41, 102.25, 222.39, 241.57]  # xmin, ymin, width, height
bbox = process_bbox(bbox, original_img_width, original_img_height)
img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0,
                                                       0.0, False,
Beispiel #26
0
def neuralwarp_train(**kwargs):
    """Train or evaluate a NeuralDTW-style cover-song model on CQT features.

    Depending on ``kwargs['manner']`` this either trains with multi-scale
    triplet CQT inputs ('train'), evaluates a checkpoint on the Mazurkas set
    ('test'), visualizes predictions ('visualize'), or evaluates the first 20
    checkpoints in the checkpoint folder ('mul_test').

    Expected kwargs (not exhaustive): 'params' (path to a JSON config),
    'manner', 'batch_size', 'model' (architecture name in ``models``),
    'Device' (CUDA device ids), 'is_label', 'is_random', 'test_length',
    'zo' (flips which half of the batch is labelled positive), 'save_model'.
    """
    # Multi-scale training: three loaders over sequence lengths 200/300/400.
    print(kwargs)

    with open(kwargs['params']) as f:
        params = json.load(f)
    if kwargs['manner'] == 'train':
        params['is_train'] = True
    else:
        params['is_train'] = False
    params['batch_size'] = kwargs['batch_size']
    if torch.cuda.device_count() > 1:
        print("-------------------Parallel_GPU_Train--------------------------")
        parallel = True
    else:
        print("------------------Single_GPU_Train----------------------")
        parallel = False
    opt.feature = 'cqt'
    opt.notes = 'SoftDTW'
    opt.model = 'SoftDTW'
    # BUG FIX: the original code assigned the literal string 'batch_size'
    # here instead of the requested batch size.
    opt.batch_size = kwargs['batch_size']

    os.environ["CUDA_VISIBLE_DEVICES"] = str(kwargs["Device"])
    opt.Device = kwargs["Device"]
    opt._parse(kwargs)

    model = getattr(models, opt.model)(params)

    p = 'check_points/' + model.model_name + opt.notes
    # Known-good pre-trained checkpoints per architecture.
    # BUG FIX: the original reused the name `f` (the closed params file
    # handle), so an unmatched model silently set opt.load_model_path to a
    # file object; now the configured default path is kept in that case.
    ckpt_file = None
    if kwargs['model'] == 'NeuralDTW_CNN_Mask_dilation_SPP':
        ckpt_file = os.path.join(p, "0704_19:58:25.pth")
    elif kwargs['model'] == 'NeuralDTW_CNN_Mask_dilation_SPP2':
        ckpt_file = os.path.join(p, "0709_00:31:23.pth")
    elif kwargs['model'] == 'NeuralDTW_CNN_Mask_dilation':
        ckpt_file = os.path.join(p, "0704_06:40:41.pth")
    if ckpt_file is not None:
        opt.load_model_path = ckpt_file
    if kwargs['model'] != 'NeuralDTW' and kwargs['manner'] != 'train':
        if opt.load_latest is True:
            model.load_latest(opt.notes)
        elif opt.load_model_path:
            print("load_model:", opt.load_model_path)
            model.load(opt.load_model_path)

    if parallel:
        model = DataParallel(model)
    model.to(opt.device)
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Data: the SPP variant trains on three different sequence lengths,
    # every other variant uses the same length for all three loaders.
    out_length = 400
    if kwargs['model'] == 'NeuralDTW_CNN_Mask_300':
        out_length = 300
    if kwargs['model'] == 'NeuralDTW_CNN_Mask_spp':
        train_data0 = triplet_CQT(out_length=200, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data1 = triplet_CQT(out_length=300, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data2 = triplet_CQT(out_length=400, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
    else:
        train_data0 = triplet_CQT(out_length=out_length, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data1 = triplet_CQT(out_length=out_length, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data2 = triplet_CQT(out_length=out_length, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
    val_data80 = CQT('songs80', out_length=kwargs['test_length'])
    val_data = CQT('songs350', out_length=kwargs['test_length'])
    val_data_marukars = CQT('Mazurkas', out_length=kwargs['test_length'])

    train_dataloader0 = DataLoader(train_data0, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    train_dataloader1 = DataLoader(train_data1, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    train_dataloader2 = DataLoader(train_data2, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    val_dataloader80 = DataLoader(val_data80, 1, shuffle=False, num_workers=1)
    val_dataloader = DataLoader(val_data, 1, shuffle=False, num_workers=1)
    val_dataloader_marukars = DataLoader(val_data_marukars, 1, shuffle=False, num_workers=1)
    if kwargs['manner'] == 'test':
        val_slow_batch(model, val_dataloader_marukars, batch=100, is_dis=kwargs['zo'])
    elif kwargs['manner'] == 'visualize':
        visualize(model, val_dataloader80)
    elif kwargs['manner'] == 'mul_test':
        # Evaluate the first 20 checkpoints and report the best summed MAP.
        p = 'check_points/' + model.model_name + opt.notes
        l = sorted(os.listdir(p))[: 20]
        best_MAP, MAP = 0, 0
        for f in l:
            f = os.path.join(p, f)
            model.load(f)
            model.to(opt.device)
            MAP += val_slow_batch(model, val_dataloader, batch=400, is_dis=kwargs['zo'])
            MAP += val_slow_batch(model, val_dataloader80, batch=400, is_dis=kwargs['zo'])
            if MAP > best_MAP:
                print('--best result--')
                best_MAP = MAP
            MAP = 0
    else:
        # criterion and optimizer
        be = torch.nn.BCELoss()

        lr = opt.lr
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=10, verbose=True, min_lr=5e-6)

        # Training: every step draws one triplet batch from each of the
        # three loaders and runs all three scales through the model.
        best_MAP = 0
        for epoch in range(opt.max_epoch):
            running_loss = 0
            num = 0
            for ii, ((a0, p0, n0, la0, lp0, ln0), (a1, p1, n1, la1, lp1, ln1), (a2, p2, n2, la2, lp2, ln2)) in tqdm(
                    enumerate(zip(train_dataloader0, train_dataloader1, train_dataloader2))):
                for flag in range(3):
                    if flag == 0:
                        a, p, n, la, lp, ln = a0, p0, n0, la0, lp0, ln0
                    elif flag == 1:
                        a, p, n, la, lp, ln = a1, p1, n1, la1, lp1, ln1
                    else:
                        a, p, n, la, lp, ln = a2, p2, n2, la2, lp2, ln2
                    B, _, _, _ = a.shape
                    # 'zo' flips which half of the concatenated batch is
                    # labelled positive for the BCE target.
                    if kwargs["zo"]:
                        target = torch.cat((torch.zeros(B), torch.ones(B))).cuda()
                    else:
                        target = torch.cat((torch.ones(B), torch.zeros(B))).cuda()
                    a = a.requires_grad_().to(opt.device)
                    p = p.requires_grad_().to(opt.device)
                    n = n.requires_grad_().to(opt.device)

                    optimizer.zero_grad()
                    pred = model(a, p, n)
                    pred = pred.squeeze(1)
                    loss = be(pred, target)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    num += a.shape[0]

                # Periodic evaluation / checkpointing (also fires at ii == 0).
                if ii % 5000 == 0:
                    running_loss /= num
                    print("train_loss:", running_loss)

                    MAP = 0
                    print("Youtube350:")
                    MAP += val_slow_batch(model, val_dataloader, batch=1, is_dis=kwargs['zo'])
                    print("CoverSong80:")
                    MAP += val_slow_batch(model, val_dataloader80, batch=1, is_dis=kwargs['zo'])
                    if MAP > best_MAP:
                        best_MAP = MAP
                        print('*****************BEST*****************')
                    if kwargs['save_model']:
                        if parallel:
                            model.module.save(opt.notes)
                        else:
                            model.save(opt.notes)
                    scheduler.step(running_loss)
                    running_loss = 0
                    num = 0
Beispiel #27
0
def execute(args):
    """Run RootNet depth estimation for every tracked person.

    For each ordered person folder under ``args.img_dir`` and each of its
    ``frame_*.json`` files, estimates the absolute depth of the person's
    root joint from the stored bbox and writes the result back into the
    same JSON under the "root" key.

    Returns True on success, False when the input dir / model is missing
    or an unexpected error occurs (logged).
    """
    try:
        logger.info('人物深度処理開始: {0}', args.img_dir, decoration=MLogger.DECORATION_BOX)

        if not os.path.exists(args.img_dir):
            logger.error("指定された処理用ディレクトリが存在しません。: {0}", args.img_dir, decoration=MLogger.DECORATION_BOX)
            return False

        # Model config comes from the module's own parser defaults, not CLI.
        parser = get_parser()
        argv = parser.parse_args(args=[])

        if not os.path.exists(argv.model_path):
            logger.error("指定された学習モデルが存在しません。: {0}", argv.model_path, decoration=MLogger.DECORATION_BOX)
            return False

        cudnn.benchmark = True

        # snapshot load
        model = get_pose_net(argv, False)
        model = DataParallel(model).to('cuda')
        ckpt = torch.load(argv.model_path)
        model.load_state_dict(ckpt['network'])
        model.eval()
        focal = [1500, 1500] # x-axis, y-axis (assumed fixed focal length in px)

        # prepare input image
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=argv.pixel_mean, std=argv.pixel_std)])

        # One folder per tracked person, ordered by numeric suffix
        ordered_person_dir_pathes = sorted(glob.glob(os.path.join(args.img_dir, "ordered", "*")), key=sort_by_numeric)

        frame_pattern = re.compile(r'^(frame_(\d+)\.png)')

        for oidx, ordered_person_dir_path in enumerate(ordered_person_dir_pathes):    
            logger.info("【No.{0}】人物深度推定開始", f"{oidx:03}", decoration=MLogger.DECORATION_LINE)

            frame_json_pathes = sorted(glob.glob(os.path.join(ordered_person_dir_path, "frame_*.json")), key=sort_by_numeric)

            for frame_json_path in tqdm(frame_json_pathes, desc=f"No.{oidx:03} ... "):                
                m = frame_pattern.match(os.path.basename(frame_json_path))
                if m:
                    frame_image_name = str(m.groups()[0])
                    fno_name = str(m.groups()[1])
                    
                    # Image path for this frame number
                    frame_image_path = os.path.join(args.img_dir, "frames", fno_name, frame_image_name)

                    if os.path.exists(frame_image_path):

                        frame_joints = {}
                        with open(frame_json_path, 'r') as f:
                            frame_joints = json.load(f)
                        
                        width = int(frame_joints['image']['width'])
                        height = int(frame_joints['image']['height'])

                        original_img = cv2.imread(frame_image_path)

                        bx = float(frame_joints["bbox"]["x"])
                        by = float(frame_joints["bbox"]["y"])
                        bw = float(frame_joints["bbox"]["width"])
                        bh = float(frame_joints["bbox"]["height"])

                        # Depth estimation with RootNet: crop the bbox patch,
                        # compute the k-value (distance factor from real bbox
                        # area, focal length and pixel bbox area), and forward.
                        bbox = process_bbox([bx, by, bw, bh], width, height, argv)
                        img, img2bb_trans = generate_patch_image(original_img, bbox, False, 0.0, argv)
                        img = transform(img).to('cuda')[None,:,:,:]
                        k_value = np.array([math.sqrt(argv.bbox_real[0] * argv.bbox_real[1] * focal[0] * focal[1] / (bbox[2] * bbox[3]))]).astype(np.float32)
                        k_value = torch.FloatTensor([k_value]).to('cuda')[None,:]

                        with torch.no_grad():
                            root_3d = model(img, k_value) # x,y: pixel, z: root-relative depth (mm)

                        img = img[0].to('cpu').numpy()
                        root_3d = root_3d[0].to('cpu').numpy()
                        # Map x/y from heatmap (output_shape) coords back to
                        # original-image pixel coords via the bbox.
                        root_3d[0] = root_3d[0] / argv.output_shape[0] * bbox[2] + bbox[0]
                        root_3d[1] = root_3d[1] / argv.output_shape[1] * bbox[3] + bbox[1]

                        # Persist result (plus the shapes/focal used) so later
                        # stages can reproject without re-running the model.
                        frame_joints["root"] = {"x": float(root_3d[0]), "y": float(root_3d[1]), "z": float(root_3d[2]), \
                                                "input": {"x": argv.input_shape[0], "y": argv.input_shape[1]}, "output": {"x": argv.output_shape[0], "y": argv.output_shape[1]}, \
                                                "focal": {"x": focal[0], "y": focal[1]}}

                        with open(frame_json_path, 'w') as f:
                            json.dump(frame_joints, f, indent=4)

        logger.info('人物深度処理終了: {0}', args.img_dir, decoration=MLogger.DECORATION_BOX)

        return True
    except Exception as e:
        # Top-level boundary: log and signal failure instead of raising.
        logger.critical("人物深度で予期せぬエラーが発生しました。", e, decoration=MLogger.DECORATION_BOX)
        return False
Beispiel #28
0
class RootNet(object):
    """Inference wrapper around a RootNet pose network.

    Given person bounding boxes in an image, estimates the 3D camera-space
    position of each person's root joint.
    """

    def __init__(self, weightsPath, principal_points=None, focal=(1500, 1500)):
        """
        :param weightsPath: path to a checkpoint file whose 'network' entry
            holds the model state dict.
        :param principal_points: (cx, cy) principal point in pixels; if None
            it is set to the image center on the first ``estimate`` call.
        :param focal: (fx, fy) focal lengths in pixels.
        """

        self.focal = focal
        self.principal_points = principal_points

        # Build the network, wrap for (multi-)GPU, load weights, eval mode.
        self.net = get_pose_net(cfg, False)
        self.net = DataParallel(self.net).cuda()
        weigths = torch.load(weightsPath)
        self.net.load_state_dict(weigths['network'])
        self.net.eval()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cfg.pixel_mean, std=cfg.pixel_std)
        ])

    def estimate(self, bboxes, image, tracking=False):
        """Estimate the root-joint 3D position for each bbox in *image*.

        :param bboxes: iterable of boxes (x1, y1, x2, y2[, ..., pid]); the
            last element is treated as a track id when ``tracking`` is True.
        :param image: BGR/RGB image array of shape (H, W, C).
        :param tracking: when True, append the bbox's trailing id to each
            output entry.
        :return: list of [bbox, root_3d] (or [bbox, root_3d, pid]) entries,
            root_3d in camera coordinates.
        """
        # Default the principal point to the image center (persists on self).
        if self.principal_points is None:
            self.principal_points = [image.shape[1] / 2, image.shape[0] / 2]

        output = []
        for bbox in bboxes:
            # Crop + resize the person patch to the network input shape.
            bbox_xywh = convertToXYWH(bbox[0], bbox[1], bbox[2], bbox[3])
            bbox_root = process_bbox(bbox_xywh, image.shape[1], image.shape[0])
            img, img2bb_trans = generate_patch_image(image, bbox_root, False,
                                                     0.0)
            img = self.transform(img).cuda()[None, :, :, :]
            # Distance factor k = sqrt(real bbox area * fx * fy / pixel bbox
            # area), as used by the RootNet formulation.
            k_value = np.array([
                math.sqrt(cfg.bbox_real[0] * cfg.bbox_real[1] * self.focal[0] *
                          self.focal[1] / (bbox_root[2] * bbox_root[3]))
            ]).astype(np.float32)
            k_value = torch.FloatTensor([k_value]).cuda()[None, :]

            # forward
            with torch.no_grad():
                root_3d = self.net(
                    img, k_value)  # x,y: pixel, z: root-relative depth (mm)
            root_3d = root_3d[0].cpu().numpy()

            # inverse affine transform (restore the crop and resize)
            root_3d[0] = root_3d[0] / cfg.output_shape[1] * cfg.input_shape[1]
            root_3d[1] = root_3d[1] / cfg.output_shape[0] * cfg.input_shape[0]
            root_3d_xy1 = np.concatenate(
                (root_3d[:2], np.ones_like(root_3d[:1])))
            img2bb_trans_001 = np.concatenate(
                (img2bb_trans, np.array([0, 0, 1]).reshape(1, 3)))
            root_3d[:2] = np.dot(np.linalg.inv(img2bb_trans_001),
                                 root_3d_xy1)[:2]
            # get 3D coordinates for bbox
            root_3d = pixel2cam(root_3d[None, :], self.focal,
                                self.principal_points)
            if tracking:
                pid = bbox[-1]
                output.append([bbox, root_3d, pid])
            else:
                output.append([bbox, root_3d])

        return output
Beispiel #29
0
def stage1_train(args):
    """Train the Stage-I text-to-image GAN (generator + discriminator).

    Trains on the Birds dataset at 64x64, alternating a discriminator step
    (real/wrong/fake losses) and a generator step (adversarial + KL loss),
    with manual learning-rate decay, optional TensorBoard summaries, and
    periodic checkpointing into ``args.s1_checkpoint_dir``.
    """
    logger = init_logger(args)
    if args.summary:
        summary_writer = SummaryWriter(args.s1_summary_path)
    dataset = Birds(args.data_dir, split='train', im_size=64)
    dataloader = DataLoader(dataset, batch_size=args.s1_batch_size, shuffle=True, num_workers=8, drop_last=True)
    generator = Stage1Generator(args.txt_embedding_dim, args.c_dim, args.z_dim, args.gf_dim).cuda()
    print('generator={}'.format(generator))
    discriminator = Stage1Discriminator(args.df_dim, args.c_dim).cuda()
    print('discriminator={}'.format(discriminator))
    device_ids = list(range(torch.cuda.device_count()))
    generator = DataParallel(generator, device_ids)
    discriminator = DataParallel(discriminator, device_ids)
    # Only trainable parameters go to the optimizers.
    g_parameters = [f for f in generator.parameters() if f.requires_grad]
    d_parameters = [f for f in discriminator.parameters() if f.requires_grad]
    g_optimizer = torch.optim.Adam(g_parameters, args.lr, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(d_parameters, args.lr, betas=(0.5, 0.999))
    r_labels = torch.ones((args.s1_batch_size,), device='cuda:0')
    f_labels = torch.zeros((args.s1_batch_size,), device='cuda:0')
    criterion = nn.BCELoss()
    cur_lr = args.lr
    for epoch in range(args.total_epoch):
        for idx, (r_imgs, txt_embeddings) in enumerate(dataloader):
            r_imgs = r_imgs.cuda()
            txt_embeddings = txt_embeddings.cuda()
            # discriminator step: real vs. fake (detached) images
            noise = torch.zeros((args.s1_batch_size, args.z_dim), device='cuda:0').normal_()
            x, mu, logvar = generator(txt_embeddings, noise)
            d_loss, r_loss, w_loss, f_loss = discriminator_loss(discriminator, r_imgs, x.detach(), mu.detach(), r_labels, f_labels, criterion)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()
            # generator step: adversarial loss + KL regularizer
            noise = torch.zeros((args.s1_batch_size, args.z_dim), device='cuda:0').normal_()
            x, mu, logvar = generator(txt_embeddings, noise)
            logits = discriminator(mu.detach(), x)
            g_loss = criterion(logits, r_labels)
            kl_loss_ = kl_loss(mu, logvar)
            g_loss += kl_loss_
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            if args.summary and idx % args.summary_iters == 0 and idx > 0:
                # BUG FIX: 'd_loss' originally logged g_loss, and 'kl_loss'
                # called .item() on the kl_loss *function* instead of the
                # computed kl_loss_ tensor. Also tag each scalar with a
                # global step so successive points don't overwrite each other.
                step = epoch * len(dataloader) + idx
                summary_writer.add_scalar('d_loss', d_loss.item(), step)
                summary_writer.add_scalar('r_loss', r_loss.item(), step)
                summary_writer.add_scalar('w_loss', w_loss.item(), step)
                summary_writer.add_scalar('f_loss', f_loss.item(), step)
                summary_writer.add_scalar('g_loss', g_loss.item(), step)
                summary_writer.add_scalar('kl_loss', kl_loss_.item(), step)
        if epoch % args.lr_decay_every_epoch == 0 and epoch > 0:
            logger.info(f'lr decay: {cur_lr}')
            cur_lr *= args.lr_decay_ratio
            g_optimizer = torch.optim.Adam(g_parameters, cur_lr, betas=(0.5, 0.999))
            d_optimizer = torch.optim.Adam(d_parameters, cur_lr, betas=(0.5, 0.999))
        if epoch % args.display_epoch == 0 and epoch > 0:
            logger.info(f'epoch:{epoch}, lr={cur_lr}, d_loss={d_loss}, r_loss={r_loss}, w_loss={w_loss}, f_loss={f_loss}, g_loss={g_loss}, kl_loss={kl_loss_}')
        if epoch % args.checkpoint_epoch == 0 and epoch > 0:
            if not os.path.isdir(args.s1_checkpoint_dir):
                os.makedirs(args.s1_checkpoint_dir)
            logger.info(f'saving checkpoints_{epoch}')
            torch.save(generator.state_dict(), os.path.join(args.s1_checkpoint_dir, f'generator_epoch_{epoch}.pth'))
            torch.save(discriminator.state_dict(), os.path.join(args.s1_checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))
    # Final snapshot. Ensure the directory exists even when no periodic
    # checkpoint was ever written.
    os.makedirs(args.s1_checkpoint_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(args.s1_checkpoint_dir, 'generator.pth'))
    # BUG FIX: the original saved the generator's weights to discriminator.pth.
    torch.save(discriminator.state_dict(), os.path.join(args.s1_checkpoint_dir, 'discriminator.pth'))
    if args.summary:
        summary_writer.close()
Beispiel #30
0
def load_model(path='rootnet/rootnet_snapshot_18.pth.tar'):
    """Build the RootNet pose network on GPU, restore its weights from
    the checkpoint at *path*, and return it in evaluation mode."""
    net = DataParallel(get_pose_net()).cuda()
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['network'])
    net.eval()
    return net