Code Example #1
File: runner_helper.py  Project: charlotte12l/torchcv
    def _make_parallel(runner, net):
        if runner.configer.get('network.distributed', default=False):
            #print('n1')
            from apex.parallel import DistributedDataParallel
            #print('n2')
            if runner.configer.get('network.syncbn', default=False):
                Log.info('Converting syncbn model...')
                from apex.parallel import convert_syncbn_model
                net = convert_syncbn_model(net)

            torch.cuda.set_device(runner.configer.get('local_rank'))
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')
            net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
            return net

        net = net.to(
            torch.device(
                'cpu' if runner.configer.get('gpu') is None else 'cuda'))
        if len(runner.configer.get('gpu')) > 1:
            from exts.tools.parallel.data_parallel import ParallelModel
            return ParallelModel(net,
                                 gather_=runner.configer.get(
                                     'network', 'gather'))

        return net
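
The distributed branch of this helper relies on init_method='env://', which makes torch.distributed read its rendezvous information from environment variables. Below is a minimal standalone sketch of that setup (my own illustration, not part of torchcv; it assumes one process per GPU started by a launcher such as python -m torch.distributed.launch or torchrun, which exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE for every worker, and init_distributed_env is a hypothetical helper name):

import torch
import torch.distributed as dist

def init_distributed_env(local_rank):
    # 'env://' picks up MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment,
    # so this function only has to pin the process to its GPU and join the process group.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

# Hypothetical launch command for 4 GPUs on one node:
#   python -m torch.distributed.launch --nproc_per_node=4 train.py
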
Code Example #2
File: runner_helper.py  Project: wxwoods/torchcv
    def _make_parallel(runner, net):
        if runner.configer.get('network.distributed', default=False):
            from apex.parallel import DistributedDataParallel
            torch.cuda.set_device(runner.configer.get('local_rank'))
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')
            net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
            return net
        else:
            net = net.to(
                torch.device(
                    'cpu' if runner.configer.get('gpu') is None else 'cuda'))
            from exts.tools.parallel.data_parallel import ParallelModel
            return ParallelModel(net,
                                 gather_=runner.configer.get(
                                     'network', 'gather'))
Code Example #3
    def make_parallel(runner, net, optimizer):
        if runner.configer.get('distributed', default=False):
            from apex.parallel import DistributedDataParallel
            if runner.configer.get('network.syncbn', default=False):
                Log.info('Converting syncbn model...')
                from apex.parallel import convert_syncbn_model
                net = convert_syncbn_model(net)
            torch.cuda.set_device(runner.configer.get('local_rank'))
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            if runner.configer.get('dtype') == 'fp16':
                from apex import amp
                net, optimizer = amp.initialize(net.cuda(), optimizer, opt_level="O1")
                net = DistributedDataParallel(net, delay_allreduce=True)
            else:
                assert runner.configer.get('dtype') == 'none'
                net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
            return net, optimizer
        net = net.to(torch.device('cpu' if runner.configer.get('gpu') is None else 'cuda'))
        if len(runner.configer.get('gpu')) > 1:
            from lib.utils.parallel.data_parallel import DataParallelModel
            return DataParallelModel(net, gather_=runner.configer.get('network', 'gather')), optimizer

        return net, optimizer
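
Note the ordering in Code Example #3's fp16 branch: the model and optimizer are passed through amp.initialize before the model is handed to apex's DistributedDataParallel. A minimal sketch of that pattern (assumes apex is installed; "O1" is the mixed-precision opt level used above, and wrap_fp16_ddp is a hypothetical helper name):

from apex import amp
from apex.parallel import DistributedDataParallel

def wrap_fp16_ddp(net, optimizer):
    # mixed precision is set up first ...
    net, optimizer = amp.initialize(net.cuda(), optimizer, opt_level="O1")
    # ... and only afterwards is the model wrapped for distributed gradient averaging
    net = DistributedDataParallel(net, delay_allreduce=True)
    return net, optimizer
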
Code Example #4
File: nnUNetTrainerV2_DDP.py  Project: zz10001/nnUNet
class nnUNetTrainerV2_DDP(nnUNetTrainerV2):
    def __init__(self,
                 plans_file,
                 fold,
                 local_rank,
                 output_folder=None,
                 dataset_directory=None,
                 batch_dice=True,
                 stage=None,
                 unpack_data=True,
                 deterministic=True,
                 distribute_batch_size=False,
                 fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory,
                         batch_dice, stage, unpack_data, deterministic, fp16)
        self.init_args = (plans_file, fold, local_rank, output_folder,
                          dataset_directory, batch_dice, stage, unpack_data,
                          deterministic, distribute_batch_size, fp16)
        self.distribute_batch_size = distribute_batch_size
        np.random.seed(local_rank)
        torch.manual_seed(local_rank)
        torch.cuda.manual_seed_all(local_rank)
        self.local_rank = local_rank

        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')

        self.val_loss_ma_alpha = 0.95
        self.val_loss_MA = None

        self.loss = None
        self.ce_loss = CrossentropyND()

        self.global_batch_size = None  # we need to know this to properly steer oversample

    def set_batch_size_and_oversample(self):
        batch_sizes = []
        oversample_percents = []

        world_size = dist.get_world_size()
        my_rank = dist.get_rank()

        if self.distribute_batch_size:
            self.global_batch_size = self.batch_size
        else:
            self.global_batch_size = self.batch_size * world_size

        batch_size_per_GPU = np.ceil(self.batch_size / world_size).astype(int)

        for rank in range(world_size):
            if self.distribute_batch_size:
                if (rank + 1) * batch_size_per_GPU > self.batch_size:
                    batch_size = batch_size_per_GPU - (
                        (rank + 1) * batch_size_per_GPU - self.batch_size)
                else:
                    batch_size = batch_size_per_GPU
            else:
                batch_size = self.batch_size

            batch_sizes.append(batch_size)

            sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(
                batch_sizes[:-1])
            sample_id_high = np.sum(batch_sizes)

            if sample_id_high / self.global_batch_size < (
                    1 - self.oversample_foreground_percent):
                oversample_percents.append(0.0)
            elif sample_id_low / self.global_batch_size > (
                    1 - self.oversample_foreground_percent):
                oversample_percents.append(1.0)
            else:
                percent_covered_by_this_rank = sample_id_high / self.global_batch_size - sample_id_low / self.global_batch_size
                oversample_percent_here = 1 - (
                    ((1 - self.oversample_foreground_percent) - sample_id_low /
                     self.global_batch_size) / percent_covered_by_this_rank)
                oversample_percents.append(oversample_percent_here)

        print("worker", my_rank, "oversample", oversample_percents[my_rank])
        print("worker", my_rank, "batch_size", batch_sizes[my_rank])

        self.batch_size = batch_sizes[my_rank]
        self.oversample_foreground_percent = oversample_percents[my_rank]

    def save_checkpoint(self, fname, save_optimizer=True):
        if self.local_rank == 0:
            super().save_checkpoint(fname, save_optimizer)

    def plot_progress(self):
        if self.local_rank == 0:
            super().plot_progress()

    def print_to_log_file(self, *args, also_print_to_console=True):
        if self.local_rank == 0:
            super().print_to_log_file(
                *args, also_print_to_console=also_print_to_console)

    def initialize_network(self):
        """
        This is specific to the U-Net and must be adapted for other network architectures
        :return:
        """
        self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
        self.print_to_log_file(self.net_conv_kernel_sizes)

        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d

        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d

        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
            InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
            self.net_conv_kernel_sizes, False, True, True)
        self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def process_plans(self, plans):
        super().process_plans(plans)
        self.set_batch_size_and_oversample()

    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases, just set training=False; this will prevent loading of the training data and
        initialization of the training batchgenerator.
        :param training:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        # we need to wait until worker 0 has finished unpacking
                        npz_files = subfiles(
                            self.folder_with_preprocessed_data,
                            suffix=".npz",
                            join=False)
                        case_ids = [i[:-4] for i in npz_files]
                        all_present = all([
                            isfile(
                                join(self.folder_with_preprocessed_data,
                                     i + ".npy")) for i in case_ids
                        ])
                        while not all_present:
                            print("worker", self.local_rank,
                                  "is waiting for unpacking")
                            sleep(3)
                            all_present = all([
                                isfile(
                                    join(self.folder_with_preprocessed_data,
                                         i + ".npy")) for i in case_ids
                            ])
                        # there is some slight chance that an error may arise because a dataloader is loading a file
                        # that is still being written by worker 0. We ignore this for now and address it only if it
                        # becomes relevant
                        # (this can occur because while worker 0 writes, the file is technically present, so the other
                        # workers will proceed and eventually try to read it)
                else:
                    print(
                        "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
                # this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2**i) for i in range(net_numpool)])

                # we don't use the lowest resolution output. Normalize weights so that they sum to 1
                mask = np.array([
                    True if i < net_numpool - 1 else False
                    for i in range(net_numpool)
                ])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights

                seeds_train = np.random.random_integers(
                    0, 99999, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.random_integers(
                    0, 99999,
                    max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr,
                    self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val)
                self.print_to_log_file("TRAINING KEYS:\n %s" %
                                       (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" %
                                       (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self._maybe_init_amp()
            self.network = DDP(self.network)

        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again'
            )
        self.was_initialized = True

    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax,
                                            target[i],
                                            axes,
                                            mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                num_classes = output[0].shape[1]
                output_seg = output[0].argmax(1)
                target = target[0][:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target == c).float(),
                        axes=axes)
                    fp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target != c).float(),
                        axes=axes)
                    fn_hard[:, c - 1] = sum_tensor(
                        (output_seg != c).float() * (target == c).float(),
                        axes=axes)

                # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
                #                                         axes, None)
                # print_if_rank0("before allgather", tp_hard.shape)
                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]

                tp_hard = awesome_allgather_function.apply(tp_hard)
                fp_hard = awesome_allgather_function.apply(fp_hard)
                fn_hard = awesome_allgather_function.apply(fn_hard)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(
                    tp_hard.detach().cpu().numpy().sum(0),
                    fp_hard.detach().cpu().numpy().sum(0),
                    fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None:
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()

    def run_online_evaluation(self, tp, fp, fn):
        self.online_eval_foreground_dc.append(
            list((2 * tp) / (2 * tp + fp + fn + 1e-8)))
        self.online_eval_tp.append(list(tp))
        self.online_eval_fp.append(list(fp))
        self.online_eval_fn.append(list(fn))

    def run_training(self):
        """
        if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
        continued epoch with self.initial_lr

        we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
        :return:
        """
        self.maybe_update_lr(
            self.epoch
        )  # if we don't overwrite epoch then self.epoch+1 is used, which is not what we
        # want at the start of the training
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = True
        ret = nnUNetTrainer.run_training(self)
        net.do_ds = ds
        return ret

    def validate(self,
                 do_mirroring: bool = True,
                 use_train_mode: bool = False,
                 tiled: bool = True,
                 step: int = 2,
                 save_softmax: bool = True,
                 use_gaussian: bool = True,
                 overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw',
                 debug: bool = False,
                 all_in_gpu: bool = False,
                 force_separate_z: bool = None,
                 interpolation_order: int = 3,
                 interpolation_order_z=0):
        if self.local_rank == 0:
            if isinstance(self.network, DDP):
                net = self.network.module
            else:
                net = self.network
            ds = net.do_ds
            net.do_ds = False
            ret = nnUNetTrainer.validate(
                self,
                do_mirroring,
                use_train_mode,
                tiled,
                step,
                save_softmax,
                use_gaussian,
                overwrite,
                validation_folder_name,
                debug,
                all_in_gpu,
                force_separate_z=force_separate_z,
                interpolation_order=interpolation_order,
                interpolation_order_z=interpolation_order_z)
            net.do_ds = ds
            return ret

    def predict_preprocessed_data_return_softmax(self,
                                                 data,
                                                 do_mirroring,
                                                 num_repeats,
                                                 use_train_mode,
                                                 batch_size,
                                                 mirror_axes,
                                                 tiled,
                                                 tile_in_z,
                                                 step,
                                                 min_size,
                                                 use_gaussian,
                                                 all_in_gpu=False):
        """
        Don't use this. If you need softmax output, use preprocess_predict_nifti and set softmax_output_file.
        :param data:
        :param do_mirroring:
        :param num_repeats:
        :param use_train_mode:
        :param batch_size:
        :param mirror_axes:
        :param tiled:
        :param tile_in_z:
        :param step:
        :param min_size:
        :param use_gaussian:
        :param all_in_gpu:
        :return:
        """
        valid = list((SegmentationNetwork, nn.DataParallel, DDP))
        assert isinstance(self.network, tuple(valid))
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = False
        ret = net.predict_3D(data,
                             do_mirroring,
                             num_repeats,
                             use_train_mode,
                             batch_size,
                             mirror_axes,
                             tiled,
                             tile_in_z,
                             step,
                             min_size,
                             use_gaussian=use_gaussian,
                             pad_border_mode=self.inference_pad_border_mode,
                             pad_kwargs=self.inference_pad_kwargs,
                             all_in_gpu=all_in_gpu)[2]
        net.do_ds = ds
        return ret

    def load_checkpoint_ram(self, saved_model, train=True):
        """
        used if the checkpoint is already in RAM
        :param saved_model:
        :param train:
        :return:
        """
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if the state dict comes from nn.DataParallel but we use a non-parallel model here then the state dict keys do not
        # match. Use heuristic to make it match
        for k, value in saved_model['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys:
                print("duh")
                key = key[7:]
            new_state_dict[key] = value

        # if we are fp16, then we need to reinitialize the network and the optimizer. Otherwise amp will throw an error
        if self.fp16:
            self.network, self.optimizer, self.lr_scheduler = None, None, None
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            # we need to reinitialize DDP here
            self.network = DDP(self.network)

        self.network.load_state_dict(new_state_dict)
        self.epoch = saved_model['epoch']
        if train:
            optimizer_state_dict = saved_model['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

            if self.lr_scheduler is not None and hasattr(
                    self.lr_scheduler, 'load_state_dict'
            ) and saved_model['lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(
                    saved_model['lr_scheduler_state_dict'])

            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = saved_model[
            'plot_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file(
                "WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
                "due to an old bug and should only appear when you are loading old models. New "
                "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)"
            )
            self.epoch = len(self.all_tr_losses)
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        self.amp_initialized = False
        self._maybe_init_amp()
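
The per-rank split computed by set_batch_size_and_oversample above is easier to follow as a standalone function. The sketch below reproduces the same arithmetic outside the trainer (illustrative only, not part of nnU-Net; the example call at the end uses made-up numbers):

import numpy as np

def split_batch_and_oversample(batch_size, world_size, oversample_foreground_percent,
                               distribute_batch_size=False):
    batch_sizes, oversample_percents = [], []
    global_batch_size = batch_size if distribute_batch_size else batch_size * world_size
    batch_size_per_gpu = int(np.ceil(batch_size / world_size))
    for rank in range(world_size):
        if distribute_batch_size:
            # the last ranks may get a smaller share so the shares sum to batch_size
            bs = min(batch_size_per_gpu, batch_size - rank * batch_size_per_gpu)
        else:
            bs = batch_size
        batch_sizes.append(bs)
        low, high = sum(batch_sizes[:-1]), sum(batch_sizes)
        if high / global_batch_size < (1 - oversample_foreground_percent):
            oversample_percents.append(0.0)   # this rank draws only random samples
        elif low / global_batch_size > (1 - oversample_foreground_percent):
            oversample_percents.append(1.0)   # this rank draws only foreground samples
        else:
            covered = (high - low) / global_batch_size
            oversample_percents.append(
                1 - (((1 - oversample_foreground_percent) - low / global_batch_size) / covered))
    return batch_sizes, oversample_percents

# e.g. split_batch_and_oversample(2, 2, 0.33) -> ([2, 2], [0.0, ~0.66]):
# the second rank oversamples roughly 66% of its samples so that the global
# batch still contains about 33% foreground-oversampled cases.
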
Code Example #5
def main(model_name, 
         mode,
         root,
         val_split,
         ckpt,
         batch_per_gpu):
    num_gpus = MPI.COMM_WORLD.Get_size()
    distributed = False
    if num_gpus > 1:
        distributed = True

    local_rank = MPI.COMM_WORLD.Get_rank() % torch.cuda.device_count()

    if distributed:
        torch.cuda.set_device(local_rank)
        host = os.environ["MASTER_ADDR"] if "MASTER_ADDR" in os.environ else "127.0.0.1"
        torch.distributed.init_process_group(
            backend="nccl",
            init_method='tcp://{}:12345'.format(host),
            rank=MPI.COMM_WORLD.Get_rank(),
            world_size=MPI.COMM_WORLD.Get_size()
        )

        synchronize()

    val_dataloader = make_dataloader(root,
                                        val_split, 
                                        mode,
                                        model_name,
                                        seq_len=16, #64, 
                                        overlap=8, #32,
                                        phase='val', 
                                        max_iters=None, 
                                        batch_per_gpu=batch_per_gpu,
                                        num_workers=16, 
                                        shuffle=False, 
                                        distributed=distributed,
                                        with_normal=False)

    if model_name == 'i3d':
        if mode == 'flow':
            model = InceptionI3d(val_dataloader.dataset.num_classes, in_channels=2, dropout_keep_prob=0.5)
        else:
            model = InceptionI3d(val_dataloader.dataset.num_classes, in_channels=3, dropout_keep_prob=0.5)
        model.replace_logits(val_dataloader.dataset.num_classes)
    elif model_name == 'r3d_18':
        model = r3d_18(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    elif model_name == 'mc3_18':
        model = mc3_18(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    elif model_name == 'r2plus1d_18':
        model = r2plus1d_18(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    elif model_name == 'c3d':
        model = C3D(pretrained=False, num_classes=val_dataloader.dataset.num_classes)
    else:
        raise NameError('unknown model name:{}'.format(model_name))

    # pdb.set_trace()
    for param in model.parameters():
        pass  # no-op: iterates over the parameters without modifying them
    
    device = torch.device('cuda')
    model.to(device)
    if distributed:
        model = apex.parallel.convert_syncbn_model(model)
        model = DDP(model.cuda(), delay_allreduce=True)
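
This snippet converts batch-norm layers to SyncBN and wraps the model with apex's DistributedDataParallel. For reference, the same wrapping can be expressed with the native PyTorch equivalents; this is an alternative, not what the project above uses, and it assumes model and local_rank are defined as in the snippet:

import torch.nn as nn

model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])
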
Code Example #6
class Solver(object):

    def __init__(self):
        """
            :param config: easydict
        """
        self.version = __version__
        # logging.info("PyTorch Version {}, Solver Version {}".format(torch.__version__, self.version))
        self.distributed = False
        self.world_size = 1
        self.local_rank = 0
        self.epoch = 0
        self.iteration = 0
        self.config = None
        self.model, self.optimizer, self.lr_policy = None, None, None
        self.step_decay = 1

        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = self.world_size > 1 or torch.cuda.device_count() > 1

        if self.distributed:
            dist.init_process_group(backend="nccl", init_method='env://')
            self.local_rank = dist.get_rank()
            torch.cuda.set_device(self.local_rank)
            logging.info('[distributed mode] world size: {}, local rank: {}.'.format(self.world_size, self.local_rank))
        else:
            logging.info('[Single GPU mode]')

    def build_environ(self):
        if self.config['environ']['deterministic']:
            cudnn.benchmark = False
            cudnn.deterministic = True
            torch.set_printoptions(precision=10)
        else:
            cudnn.benchmark = True

        if self.config['apex']:
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

        # set random seed
        torch.manual_seed(self.config['environ']['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.config['environ']['seed'])
        np.random.seed(self.config['environ']['seed'])
        random.seed(self.config['environ']['seed'])

    def init_from_scratch(self, config):
        t_start = time.time()
        self.config = config
        self.build_environ()
        # model and optimizer
        self.model = _get_model(self.config)
        model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = _get_optimizer(config['solver']['optimizer'],
                                        model_params=model_params)

        self.lr_policy = _get_lr_policy(config['solver']['lr_policy'], optimizer=self.optimizer)
        self.step_decay = config['solver']['step_decay']

        if config['model'].get('pretrained_model') is not None:
            logging.info('loading pretrained model from {}.'.format(config['model']['pretrained_model']))
            load_model(self.model, config['model']['pretrained_model'], distributed=False)

        self.model.cuda(self.local_rank)

        if self.distributed:
            self.model = convert_syncbn_model(self.model)

        if self.config['apex']['amp_used']:
            # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info("Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}.".
                         format(self.config['apex']['opt_level'],
                                self.config['apex']['keep_batchnorm_fp32'],
                                self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                        opt_level=self.config['apex']['opt_level'],
                                                        keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                                                        loss_scale=self.config['apex']["loss_scale"])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        t_end = time.time()
        logging.info("Init trainer from scratch, Time usage: IO: {}".format(t_end - t_start))

    def init_from_checkpoint(self, continue_state_object):
        t_start = time.time()

        self.config = continue_state_object['config']
        self.build_environ()
        self.model = _get_model(self.config)
        model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = _get_optimizer(self.config['solver']['optimizer'],
                                        model_params=model_params)
        self.lr_policy = _get_lr_policy(self.config['solver']['lr_policy'], optimizer=self.optimizer)

        load_model(self.model, continue_state_object['model'], distributed=False)
        self.model.cuda(self.local_rank)

        if self.distributed:
            self.model = convert_syncbn_model(self.model)

        if self.config['apex']['amp_used']:
            # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info("Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}.".
                         format(self.config['apex']['opt_level'],
                                self.config['apex']['keep_batchnorm_fp32'],
                                self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer,
                                                        opt_level=self.config['apex']['opt_level'],
                                                        keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                                                        loss_scale=self.config['apex']["loss_scale"])
            amp.load_state_dict(continue_state_object['amp'])

        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        self.optimizer.load_state_dict(continue_state_object['optimizer'])
        self.lr_policy.load_state_dict(continue_state_object['lr_policy'])

        self.step_decay = self.config['solver']['step_decay']
        self.epoch = continue_state_object['epoch']
        self.iteration = continue_state_object["iteration"]

        del continue_state_object
        t_end = time.time()
        logging.info("Init trainer from checkpoint, Time usage: IO: {}".format(t_end - t_start))

    def step(self, **kwargs):
        """
        :param kwargs:
        :return:
        """
        self.iteration += 1
        loss = self.model(**kwargs)
        loss /= self.step_decay

        # backward
        if self.distributed and self.config['apex']['amp_used']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self.iteration % self.step_decay == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.distributed:
            reduced_loss = reduce_tensor(loss.data, self.world_size)
        else:
            reduced_loss = loss.data
        return reduced_loss

    def step_no_grad(self, **kwargs):
        with torch.no_grad():
            out = self.model(**kwargs)
        return out

    def before_epoch(self, epoch):
        self.iteration = 0
        self.epoch = epoch
        self.model.train()
        self.synchronize()
        torch.cuda.empty_cache()
        self.lr_policy.step(epoch)

    def after_epoch(self, epoch):
        self.model.eval()
        self.synchronize()
        torch.cuda.empty_cache()

    def synchronize(self):
        synchronize()

    def save_checkpoint(self, path):
        if self.local_rank == 0:
            # logging.info("Saving checkpoint to file {}".format(path))
            t_start = time.time()

            state_dict = {}

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in self.model.state_dict().items():
                key = k
                if k.split('.')[0] == 'module':
                    key = k[7:]
                new_state_dict[key] = v

            if self.config['apex']['amp_used']:
                state_dict['amp'] = amp.state_dict()
            state_dict['config'] = self.config
            state_dict['model'] = new_state_dict
            state_dict['optimizer'] = self.optimizer.state_dict()
            state_dict['lr_policy'] = self.lr_policy.state_dict()
            state_dict['epoch'] = self.epoch
            state_dict['iteration'] = self.iteration

            t_iobegin = time.time()
            torch.save(state_dict, path)
            del state_dict
            del new_state_dict
            t_end = time.time()
            logging.info(
                "Save checkpoint to file {}, "
                "Time usage:\n\tprepare snapshot: {}, IO: {}".format(
                    path, t_iobegin - t_start, t_end - t_iobegin))

    def save_images(self, filenames, image):
        raise NotImplementedError

    def copy_config(self, snapshot_dir, config_file):
        ensure_dir(snapshot_dir)
        assert osp.exists(config_file), "config file does not exist."
        new_file_name = osp.join(snapshot_dir, 'config.json')
        shutil.copy(config_file, new_file_name)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        torch.cuda.empty_cache()
        if type is not None:
            logging.warning(
                "An exception occurred during Engine initialization, "
                "giving up the pspnet_ade process")
            return False
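
Solver.step calls a reduce_tensor helper that is not shown in this snippet. A common implementation, matching the one defined in Code Example #7 below, averages a tensor across all ranks:

import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    # sum over all processes, then divide by the world size to get the mean
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt
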
Code Example #7
class RunManager:
    def __init__(self, path, net, run_config: RunConfig, out_log=True):
        self.path = path
        self.net = net
        self.run_config = run_config
        self.out_log = out_log

        self._logs_path, self._save_path = None, None
        self.best_acc = 0
        self.start_epoch = 0
        gpu = self.run_config.local_rank
        torch.cuda.set_device(gpu)

        # initialize model (default)
        self.net.init_model(run_config.model_init, run_config.init_div_groups)

        # net info
        self.net = self.net.cuda()
        if run_config.local_rank == 0:
            self.print_net_info()

        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        self.run_config.init_lr = self.run_config.init_lr * float(
            self.run_config.train_batch_size *
            self.run_config.world_size) / 256.
        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(
                self.net.weight_parameters())
        # self.net, self.optimizer = amp.initialize(self.net, self.optimizer, opt_level='O1')
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True

    """ save path and log path """

    @property
    def save_path(self):
        if self._save_path is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self._save_path = save_path
        return self._save_path

    @property
    def logs_path(self):
        if self._logs_path is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self._logs_path = logs_path
        return self._logs_path

    """ net info """

    def reset_model(self, model, model_origin=None):
        self.net = model
        self.net.init_model(self.run_config.model_init,
                            self.run_config.init_div_groups)
        if model_origin is not None:
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' start pruning ' + '-' * 30)
            get_unpruned_weights(self.net, model_origin)
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' end pruning ' + '-' * 30)
        # net info
        self.net = self.net.cuda()
        if self.run_config.local_rank == 0:
            self.print_net_info()

        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(
                self.net.weight_parameters())
        # model, self.optimizer = amp.initialize(model, self.optimizer,
        #                                        opt_level='O2',
        #                                        keep_batchnorm_fp32=True,
        #                                        loss_scale=1.0
        #                                        )
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True
        # if model_origin!=None:
        #     if self.run_config.local_rank==0:
        #         print('-'*30+' start training bn '+'-'*30)
        #     self.train_bn(1)
        #     if self.run_config.local_rank==0:
        #         print('-'*30+' end training bn '+'-'*30)

    # noinspection PyUnresolvedReferences
    def net_flops(self):
        data_shape = [1] + list(self.run_config.data_provider.data_shape)

        net = self.net
        input_var = torch.zeros(data_shape).cuda()
        with torch.no_grad():
            flops = profile_macs(net, input_var)
        return flops

    def print_net_info(self):
        # parameters
        total_params = count_parameters(self.net)
        if self.out_log:
            print('Total training params: %.2fM' % (total_params / 1e6))
        net_info = {
            'param': '%.2fM' % (total_params / 1e6),
        }

        # flops
        flops = self.net_flops()
        if self.out_log:
            print('Total FLOPs: %.1fM' % (flops / 1e6))
        net_info['flops'] = '%.1fM' % (flops / 1e6)

        # config
        if self.out_log:
            print('Net config: ' + str(self.net.config))
        net_info['config'] = str(self.net.config)

        with open('%s/net_info.txt' % self.logs_path, 'w') as fout:
            fout.write(json.dumps(net_info, indent=4) + '\n')

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.net.module.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint[
            'dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            if self.out_log:
                print("=> loading checkpoint '{}'".format(model_fname))

            if torch.cuda.is_available():
                checkpoint = torch.load(model_fname)
            else:
                checkpoint = torch.load(model_fname, map_location='cpu')

            self.net.module.load_state_dict(checkpoint['state_dict'])
            # set new manual seed
            new_manual_seed = int(time.time())
            torch.manual_seed(new_manual_seed)
            torch.cuda.manual_seed_all(new_manual_seed)
            np.random.seed(new_manual_seed)

            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'best_acc' in checkpoint:
                self.best_acc = checkpoint['best_acc']
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            if self.out_log:
                print("=> loaded checkpoint '{}'".format(model_fname))
        except Exception:
            if self.out_log:
                print('fail to load checkpoint from %s' % self.save_path)

    def save_config(self, print_info=True):
        """ dump run_config and net_config to the model_folder """
        os.makedirs(self.path, exist_ok=True)
        net_save_path = os.path.join(self.path, 'net.config')
        json.dump(self.net.module.config, open(net_save_path, 'w'), indent=4)
        if print_info:
            print('Network configs dump to %s' % net_save_path)

        run_save_path = os.path.join(self.path, 'run.config')
        json.dump(self.run_config.config, open(run_save_path, 'w'), indent=4)
        if print_info:
            print('Run configs dump to %s' % run_save_path)

    """ train and test """

    def write_log(self, log_str, prefix, should_print=True):
        """ prefix: valid, train, test """
        if prefix in ['valid', 'test']:
            with open(os.path.join(self.logs_path, 'valid_console.txt'),
                      'a') as fout:
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['valid', 'test', 'train']:
            with open(os.path.join(self.logs_path, 'train_console.txt'),
                      'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['prune']:
            with open(os.path.join(self.logs_path, 'prune_console.txt'),
                      'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if should_print:
            print(log_str)

    def validate(self,
                 is_test=True,
                 net=None,
                 use_train_mode=False,
                 return_top5=False):
        if is_test:
            data_loader = self.run_config.test_loader
        else:
            data_loader = self.run_config.valid_loader

        if net is None:
            net = self.net

        if use_train_mode:
            net.train()
        else:
            net.eval()
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        # noinspection PyUnresolvedReferences
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                images, labels = data[0].cuda(non_blocking=True), data[1].cuda(
                    non_blocking=True)
                # images, labels = data[0].cuda(), data[1].cuda()
                # compute output
                output = net(images)
                loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                reduced_loss = self.reduce_tensor(loss.data)
                acc1 = self.reduce_tensor(acc1)
                acc5 = self.reduce_tensor(acc5)
                losses.update(reduced_loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.run_config.print_frequency == 0 or i + 1 == len(
                        data_loader):
                    if is_test:
                        prefix = 'Test'
                    else:
                        prefix = 'Valid'
                    test_log = prefix + ': [{0}/{1}]\t' \
                                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
                        format(i, len(data_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
                    if return_top5:
                        test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(
                            top5=top5)
                    print(test_log)
        self.run_config.valid_loader.reset()
        self.run_config.test_loader.reset()
        if return_top5:
            return losses.avg, top1.avg, top5.avg
        else:
            return losses.avg, top1.avg

    def train_bn(self, epochs=1):
        if self.run_config.local_rank == 0:
            print('training bn')
        for m in self.net.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)
        self.net.train()
        for i in range(epochs):
            for _, data in enumerate(self.run_config.train_loader):
                images, labels = data[0].cuda(non_blocking=True), data[1].cuda(
                    non_blocking=True)
                output = self.net(images)
                del output, images, labels
        if self.run_config.local_rank == 0:
            print('training bn finished')

    def train_one_epoch(self, adjust_lr_func, train_log_func, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, data in enumerate(self.run_config.train_loader):
            data_time.update(time.time() - end)
            new_lr = adjust_lr_func(i)
            images, labels = data[0].cuda(non_blocking=True), data[1].cuda(
                non_blocking=True)
            # compute output
            output = self.net(images)
            if self.run_config.label_smoothing > 0:
                loss = cross_entropy_with_label_smoothing(
                    output, labels, self.run_config.label_smoothing)
            else:
                loss = self.criterion(output, labels)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            reduced_loss = self.reduce_tensor(loss.data)
            acc1 = self.reduce_tensor(acc1)
            acc5 = self.reduce_tensor(acc5)
            losses.update(reduced_loss, images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # compute gradient and do SGD step
            self.net.zero_grad()  # or self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            torch.cuda.synchronize()
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (i % self.run_config.print_frequency == 0
                    or i + 1 == len(self.run_config.train_loader)
                ) and self.run_config.local_rank == 0:
                batch_log = train_log_func(i, batch_time, data_time, losses,
                                           top1, top5, new_lr)
                self.write_log(batch_log, 'train')
        return top1, top5

    def train(self, print_top5=False):
        def train_log_func(epoch_, i, batch_time, data_time, losses, top1,
                           top5, lr):
            batch_log = 'Train [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                        'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
                format(epoch_ + 1, i, len(self.run_config.train_loader) - 1,
                       batch_time=batch_time, data_time=data_time, losses=losses, top1=top1)
            if print_top5:
                batch_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(
                    top5=top5)
            batch_log += '\tlr {lr:.5f}'.format(lr=lr)
            return batch_log

        for epoch in range(self.start_epoch, self.run_config.n_epochs):
            if self.run_config.local_rank == 0:
                print('\n', '-' * 30, 'Train epoch: %d' % (epoch + 1),
                      '-' * 30, '\n')

            end = time.time()
            train_top1, train_top5 = self.train_one_epoch(
                lambda i: self.run_config.adjust_learning_rate(
                    self.optimizer, epoch, i, len(self.run_config.train_loader)
                ), lambda i, batch_time, data_time, losses, top1, top5, new_lr:
                train_log_func(epoch, i, batch_time, data_time, losses, top1,
                               top5, new_lr), epoch)
            time_per_epoch = time.time() - end
            seconds_left = int(
                (self.run_config.n_epochs - epoch - 1) * time_per_epoch)
            if self.run_config.local_rank == 0:
                print('Time per epoch: %s, Est. complete in: %s' %
                      (str(timedelta(seconds=time_per_epoch)),
                       str(timedelta(seconds=seconds_left))))

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                val_loss, val_acc, val_acc5 = self.validate(is_test=False,
                                                            return_top5=True)
                is_best = val_acc > self.best_acc
                self.best_acc = max(self.best_acc, val_acc)
                val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f} ({4:.3f})'. \
                    format(epoch + 1, self.run_config.n_epochs, val_loss, val_acc, self.best_acc)
                if print_top5:
                    val_log += '\ttop-5 acc {0:.3f}\tTrain top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                        format(val_acc5, top1=train_top1, top5=train_top5)
                else:
                    val_log += '\tTrain top-1 {top1.avg:.3f}'.format(
                        top1=train_top1)
                if self.run_config.local_rank == 0:
                    self.write_log(val_log, 'valid')
            else:
                is_best = False
            if self.run_config.local_rank == 0:
                self.save_model(
                    {
                        'epoch': epoch,
                        'best_acc': self.best_acc,
                        'optimizer': self.optimizer.state_dict(),
                        'state_dict': self.net.state_dict(),
                    },
                    is_best=is_best)
            self.run_config.train_loader.reset()
            self.run_config.valid_loader.reset()
            self.run_config.test_loader.reset()

    def reduce_tensor(self, tensor):
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.run_config.world_size
        return rt
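
The reduce_tensor helper above implements the usual distributed-averaging pattern: sum a scalar across all ranks with an all-reduce, then divide by the world size. Below is a minimal self-contained sketch of the same idea, assuming the process group has already been initialized as in the surrounding snippets; the function name is illustrative only.

import torch
import torch.distributed as dist

def average_across_ranks(value: torch.Tensor, world_size: int) -> torch.Tensor:
    # Every rank contributes its local value; after the all-reduce each rank holds
    # the sum, which is divided by world_size to obtain the mean (typically used
    # to report a loss or accuracy consistently on rank 0).
    reduced = value.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced /= world_size
    return reduced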
Code example #8
File: train_distribute.py    Project: zymale/Fast_Seg
def main():

    # make save dir
    if args.local_rank == 0:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
    # launch the logger
    Log.init(
        log_level=args.log_level,
        log_file=osp.join(args.save_dir, args.log_file),
        log_format=args.log_format,
        rewrite=args.rewrite,
        stdout_level=args.stdout_level
    )
    # RGB or BGR input (RGB input for ImageNet-pretrained models, BGR input for Caffe-pretrained models)
    if args.rgb:
        IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
        IMG_VARS = np.array((0.229, 0.224, 0.225), dtype=np.float32)
    else:
        IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
        IMG_VARS = np.array((1, 1, 1), dtype=np.float32)
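    # (Sketch of the assumed normalization, for clarity; the dataset class that consumes
    #  IMG_MEAN/IMG_VARS is not part of this excerpt. With the RGB/ImageNet statistics the
    #  image is typically scaled to [0, 1] and then normalized per channel, e.g.
    #  img = (img / 255.0 - IMG_MEAN) / IMG_VARS, while with the BGR/Caffe statistics the
    #  per-channel mean is subtracted from the raw 0-255 values and IMG_VARS of 1 leaves
    #  the scale unchanged.)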

    # set models
    import libs.models as models
    deeplab = models.__dict__[args.arch](num_classes=args.num_classes, data_set=args.data_set)
    if args.restore_from is not None:
        saved_state_dict = torch.load(args.restore_from, map_location=torch.device('cpu'))
        new_params = deeplab.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
        Log.info("load pretrined models")
        if deeplab.backbone is not None:
            deeplab.backbone.load_state_dict(new_params, strict=False)
        else:
            deeplab.load_state_dict(new_params, strict=False)
    else:
        Log.info("train from stracth")


    args.world_size = 1

    if 'WORLD_SIZE' in os.environ and args.apex:
        args.apex = int(os.environ['WORLD_SIZE']) > 1
        args.world_size = int(os.environ['WORLD_SIZE'])
        print("Total world size: ", int(os.environ['WORLD_SIZE']))

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    h, w = args.input_size, args.input_size
    input_size = (h, w)


    # Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)
    Log.info("Local Rank: {}".format(args.local_rank))
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    # set optimizer
    optimizer = optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate}],
        lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # set on cuda
    deeplab.cuda()

    # models transformation: convert BN to SyncBN before wrapping with DistributedDataParallel,
    # otherwise the replaced BN parameters would not be registered for gradient all-reduce
    deeplab = apex.parallel.convert_syncbn_model(deeplab)
    model = DistributedDataParallel(deeplab)
    model.train()
    model.float()
    model.cuda()

    # set loss function
    if args.ohem:
        criterion = CriterionOhemDSN(thresh=args.ohem_thres, min_kept=args.ohem_keep)  # OHEM CrossEntrop
        if "ic" in args.arch:
            criterion = CriterionICNet(thresh=args.ohem_thres, min_kept=args.ohem_keep)
        if "dfa" in args.arch:
            criterion = CriterionDFANet(thresh=args.ohem_thres, min_kept=args.ohem_keep)
    else:
        criterion = CriterionDSN()  # CrossEntropy
    criterion.cuda()

    cudnn.benchmark = True

    if args.world_size == 1:
        print(model)

    # This is a little different from the single-node multi-GPU setting: in distributed training
    # each process owns its own trainloader and samples from the dataset class independently.
    batch_size = args.gpu_num * args.batch_size_per_gpu
    max_iters = args.num_steps * batch_size / args.gpu_num
    # set data loader
    data_set = Cityscapes(args.data_dir, args.data_list, max_iters=max_iters, crop_size=input_size,
                  scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN,vars=IMG_VARS, RGB= args.rgb)

    trainloader = data.DataLoader(
        data_set,
        batch_size=args.batch_size_per_gpu, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    print("trainloader", len(trainloader))

    torch.cuda.empty_cache()

    # start training:
    for i_iter, batch in enumerate(trainloader):
        images, labels = batch
        images = images.cuda()
        labels = labels.long().cuda()
        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, args, i_iter, len(trainloader))
        preds = model(images)

        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        reduce_loss = all_reduce_tensor(loss,
                                        world_size=args.gpu_num)
        if args.local_rank == 0:
            Log.info('iter = {} of {} completed, lr={}, loss = {}'.format(i_iter,
                                                                      len(trainloader), lr, reduce_loss.data.cpu().numpy()))
            if i_iter % args.save_pred_every == 0 and i_iter > args.save_start:
                print('save models ...')
                torch.save(deeplab.state_dict(), osp.join(args.save_dir, str(args.arch) + str(i_iter) + '.pth'))

    end = timeit.default_timer()

    if args.local_rank == 0:
        Log.info("Training cost: "+ str(end - start) + 'seconds')
        Log.info("Save final models")
        torch.save(deeplab.state_dict(), osp.join(args.save_dir, str(args.arch) + '_final' + '.pth'))
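
adjust_learning_rate is called in the loop above but is not included in this excerpt; the following is only a sketch of the polynomial ("poly") decay schedule commonly paired with such segmentation trainers. The power value, the default signature, and the reliance on args.learning_rate are assumptions.

def adjust_learning_rate(optimizer, args, i_iter, total_iters, power=0.9):
    # Polynomial decay from the base learning rate towards zero over total_iters iterations.
    lr = args.learning_rate * ((1.0 - float(i_iter) / total_iters) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr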
Code example #9
File: solver.py    Project: minghanz/C3D_DORN
class Solver(object):
    def __init__(self):
        """
            :param config: easydict
        """
        self.version = __version__
        # logging.info("PyTorch Version {}, Solver Version {}".format(torch.__version__, self.version))
        self.distributed = False
        self.world_size = 1
        self.local_rank = 0
        self.epoch = 0
        self.iteration = 0
        self.config = None
        self.model, self.optimizer, self.lr_policy = None, None, None
        self.step_decay = 1
        self.filtered_keys = None

        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = self.world_size > 1 or torch.cuda.device_count() > 1

        if self.distributed:
            dist.init_process_group(backend="nccl", init_method='env://')
            self.local_rank = dist.get_rank()
            torch.cuda.set_device(self.local_rank)
            logging.info(
                '[distributed mode] world size: {}, local rank: {}.'.format(
                    self.world_size, self.local_rank))
        else:
            logging.info('[Single GPU mode]')

    def _build_environ(self):
        if self.config['environ']['deterministic']:
            cudnn.benchmark = False
            cudnn.deterministic = True
            torch.set_printoptions(precision=10)
        else:
            cudnn.benchmark = True

        if self.config['apex']:
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

        # set random seed
        torch.manual_seed(self.config['environ']['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.config['environ']['seed'])
        np.random.seed(self.config['environ']['seed'])
        random.seed(self.config['environ']['seed'])

        # grad clip settings
        self.grad_clip_params = self.config["solver"]["optimizer"].get(
            "grad_clip")
        self.use_grad_clip = self.grad_clip_params is not None
        if self.use_grad_clip:
            logging.info("Using grad clip with params {}".format(
                self.grad_clip_params))
        else:
            logging.info("Not using grad clip.")

    def init_from_scratch(self, config):
        t_start = time.time()
        self.config = config
        self._build_environ()
        # model and optimizer
        self.model = _get_model(self.config)
        self.filtered_keys = [
            p.name
            for p in inspect.signature(self.model.forward).parameters.values()
        ]
        # logging.info("filtered keys:{}".format(self.filtered_keys))
        # model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        model_params = []
        for params in self.model.optimizer_params():
            params["lr"] = self.config["solver"]["optimizer"]["params"][
                "lr"] * params["lr"]
            model_params.append(params)
        self.optimizer = _get_optimizer(config['solver']['optimizer'],
                                        model_params=model_params)

        self.lr_policy = _get_lr_policy(config['solver']['lr_policy'],
                                        optimizer=self.optimizer)
        self.step_decay = config['solver']['step_decay']

        if config['model'].get('pretrained_model') is not None:
            logging.info('loading pretrained model from {}.'.format(
                config['model']['pretrained_model']))
            load_model(self.model,
                       config['model']['pretrained_model'],
                       distributed=False)

        self.model.cuda(self.local_rank)

        if self.distributed:
            self.model = convert_syncbn_model(self.model)

        if self.config['apex']['amp_used']:
            # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info(
                "Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}."
                .format(self.config['apex']['opt_level'],
                        self.config['apex']['keep_batchnorm_fp32'],
                        self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        t_end = time.time()
        logging.info(
            "Init trainer from scratch, Time usage: IO: {}".format(t_end -
                                                                   t_start))

    def init_from_checkpoint(self, continue_state_object):
        t_start = time.time()

        self.config = continue_state_object['config']
        self._build_environ()
        self.model = _get_model(self.config)
        self.filtered_keys = [
            p.name
            for p in inspect.signature(self.model.forward).parameters.values()
        ]
        # model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        model_params = []
        for params in self.model.optimizer_params():
            params["lr"] = self.config["solver"]["optimizer"]["params"][
                "lr"] * params["lr"]
            model_params.append(params)
        self.optimizer = _get_optimizer(self.config['solver']['optimizer'],
                                        model_params=model_params)
        self.lr_policy = _get_lr_policy(self.config['solver']['lr_policy'],
                                        optimizer=self.optimizer)

        load_model(self.model,
                   continue_state_object['model'],
                   distributed=False)
        self.model.cuda(self.local_rank)

        if self.distributed:
            self.model = convert_syncbn_model(self.model)

        if self.config['apex']['amp_used']:
            # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info(
                "Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}."
                .format(self.config['apex']['opt_level'],
                        self.config['apex']['keep_batchnorm_fp32'],
                        self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
            amp.load_state_dict(continue_state_object['amp'])

        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        self.optimizer.load_state_dict(continue_state_object['optimizer'])
        self.lr_policy.load_state_dict(continue_state_object['lr_policy'])

        self.step_decay = self.config['solver']['step_decay']
        self.epoch = continue_state_object['epoch']
        self.iteration = continue_state_object["iteration"]

        del continue_state_object
        t_end = time.time()
        logging.info(
            "Init trainer from checkpoint, Time usage: IO: {}".format(t_end -
                                                                      t_start))

    def parse_kwargs(self, minibatch):
        kwargs = {
            k: v
            for k, v in minibatch.items() if k in self.filtered_keys
        }
        if torch.cuda.is_available():
            kwargs = tensor2cuda(kwargs)
        return kwargs

    def step(self, **kwargs):
        """
        :param kwargs:
        :return:
        """
        self.iteration += 1
        # loss = self.model(**kwargs)
        loss, loss_dorn, loss_c3d = self.model(**kwargs)
        loss_dorn /= self.step_decay
        loss_c3d /= self.step_decay

        loss /= self.step_decay

        # backward
        if self.distributed and self.config['apex']['amp_used']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self.iteration % self.step_decay == 0:
            if self.use_grad_clip:
                clip_grad_norm_(self.model.parameters(),
                                **self.grad_clip_params)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.lr_policy.step(self.epoch)

        if self.distributed:
            reduced_loss = reduce_tensor(loss.data, self.world_size)
            reduced_loss_dorn = reduce_tensor(loss_dorn.data, self.world_size)
            reduced_loss_c3d = reduce_tensor(loss_c3d.data, self.world_size)
        else:
            reduced_loss = loss.data
            reduced_loss_dorn = loss_dorn.data
            reduced_loss_c3d = loss_c3d.data
        # return reduced_loss
        return reduced_loss, reduced_loss_dorn, reduced_loss_c3d

    def step_no_grad(self, **kwargs):
        with torch.no_grad():
            out = self.model(**kwargs)
        return out

    def before_epoch(self, epoch):
        synchronize()
        self.iteration = 0
        self.epoch = epoch
        self.model.train()
        # self.lr_policy.step(epoch)
        torch.cuda.empty_cache()

    def after_epoch(self, epoch=None):
        synchronize()
        self.model.eval()
        # gc.collect()
        torch.cuda.empty_cache()

    def save_checkpoint(self, path):
        if self.local_rank == 0:
            # logging.info("Saving checkpoint to file {}".format(path))
            t_start = time.time()

            state_dict = {}

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in self.model.state_dict().items():
                key = k
                if k.split('.')[0] == 'module':
                    key = k[7:]
                new_state_dict[key] = v

            if self.config['apex']['amp_used']:
                state_dict['amp'] = amp.state_dict()
            state_dict['config'] = self.config
            state_dict['model'] = new_state_dict
            state_dict['optimizer'] = self.optimizer.state_dict()
            state_dict['lr_policy'] = self.lr_policy.state_dict()
            state_dict['epoch'] = self.epoch
            state_dict['iteration'] = self.iteration

            t_iobegin = time.time()
            torch.save(state_dict, path)
            del state_dict
            del new_state_dict
            t_end = time.time()
            logging.info("Save checkpoint to file {}, "
                         "Time usage:\n\tprepare snapshot: {}, IO: {}".format(
                             path, t_iobegin - t_start, t_end - t_iobegin))

    def get_learning_rates(self):
        lrs = []
        for i in range(len(self.optimizer.param_groups)):
            lrs.append(self.optimizer.param_groups[i]['lr'])
        return lrs
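
A hypothetical driver loop assembled only from the Solver methods defined above; the config object, the data loader, and the checkpoint naming are placeholders, not part of the original project.

def run_training(solver, config, train_loader, n_epochs, ckpt_pattern="solver_epoch_{:03d}.pth"):
    # Placeholder driver: initialize once, then alternate before_epoch / step / after_epoch
    # and snapshot the solver state once per epoch (rank 0 only, handled inside save_checkpoint).
    solver.init_from_scratch(config)
    for epoch in range(n_epochs):
        solver.before_epoch(epoch)
        for minibatch in train_loader:
            kwargs = solver.parse_kwargs(minibatch)
            loss, loss_dorn, loss_c3d = solver.step(**kwargs)
        solver.after_epoch(epoch)
        solver.save_checkpoint(ckpt_pattern.format(epoch))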
Code example #10
File: trainer.py    Project: OlegJakushkin/cdarts-mld
class CdartsTrainer(object):
    def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
                 regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
                 epochs=64, steps_per_epoch=128, fake_batch=128,  loss_alpha=2, loss_T=2, distributed=True,
                 log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
                 w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
                 nasnet_lr=0.2, local_rank=0, share_module=True):
        """
        Initialize a CdartsTrainer.

        Parameters
        ----------
        model_small : nn.Module
            PyTorch model to be trained. This is the search network of CDARTS.
        model_large : nn.Module
            PyTorch model to be trained. This is the evaluation network of CDARTS.
        criterion : callable
            Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
        loaders : list of torch.utils.data.DataLoader
            List of train data and valid data loaders, for training weights and architecture weights respectively.
        samplers : list of torch.utils.data.Sampler
            List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
            In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
        logger : logging.Logger
            The logger for logging. Will use nni logger by default (if logger is ``None``).
        regular_coeff : float
            The coefficient of regular loss.
        regular_ratio : float
            The ratio of regular loss.
        warmup_epochs : int
            The epochs to warmup the search network
        fix_head : bool
            ``True`` to fix the parameters of the auxiliary heads, ``False`` to leave them trainable.
        epochs : int
            Number of epochs planned for training.
        steps_per_epoch : int
            Steps of one epoch.
        fake_batch : int
            Number of gradient-accumulation sub-batches; the effective batch is batch_size * fake_batch, used for memory saving.
        loss_alpha : float
            The loss coefficient.
        loss_T : float
            The loss coefficient.
        distributed : bool
            ``True`` if using distributed training, else non-distributed training.
        log_frequency : int
            Step count per logging.
        grad_clip : float
            Gradient clipping for weights.
        interactive_type : string
            ``kl`` or ``smoothl1``.
        output_path : string
            Log storage path.
        w_lr : float
            Learning rate of the search network parameters.
        w_momentum : float
            Momentum of the search and the evaluation network.
        w_weight_decay : float
            The weight decay of the search and the evaluation network parameters.
        alpha_lr : float
            Learning rate of the architecture parameters.
        alpha_weight_decay : float
            The weight decay of the architecture parameters.
        nasnet_lr : float
            Learning rate of the evaluation network parameters.
        local_rank : int
            The local rank of the current process.
        share_module : bool
            ``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
        """
        if logger is None:
            logger = logging.getLogger(__name__)
        train_loader, valid_loader = loaders
        train_sampler, valid_sampler = samplers

        self.train_loader = CyclicIterator(train_loader, train_sampler, distributed)
        self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed)

        self.regular_coeff = regular_coeff
        self.regular_ratio = regular_ratio
        self.warmup_epochs = warmup_epochs
        self.fix_head = fix_head
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        if self.steps_per_epoch is None:
            self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader))
        self.fake_batch = fake_batch

        self.loss_alpha = loss_alpha
        self.grad_clip = grad_clip
        if interactive_type == "kl":
            self.interactive_loss = InteractiveKLLoss(loss_T)
        elif interactive_type == "smoothl1":
            self.interactive_loss = nn.SmoothL1Loss()
        self.loss_T = loss_T
        self.distributed = distributed
        self.log_frequency = log_frequency
        self.main_proc = not distributed or local_rank == 0

        self.logger = logger
        self.checkpoint_dir = output_path
        if self.main_proc:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        if distributed:
            torch.distributed.barrier()

        self.model_small = model_small
        self.model_large = model_large
        if self.fix_head:
            for param in self.model_small.aux_head.parameters():
                param.requires_grad = False
            for param in self.model_large.aux_head.parameters():
                param.requires_grad = False

        self.mutator_small = RegularizedDartsMutator(self.model_small).cuda()
        self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda()
        self.criterion = criterion

        self.optimizer_small = apex.optimizers.FusedSGD(self.model_small.parameters(), w_lr,
                                                        momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_large = apex.optimizers.FusedSGD(self.model_large.parameters(), nasnet_lr,
                                                        momentum=w_momentum, weight_decay=w_weight_decay)
        self.optimizer_alpha = apex.optimizers.FusedAdam(self.mutator_small.parameters(), alpha_lr)

        if distributed:
            self.model_small = apex.parallel.convert_syncbn_model(self.model_small)
            self.model_large = apex.parallel.convert_syncbn_model(self.model_large)
            self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True)
            self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True)
            self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True)
            if share_module:
                self.model_small.callback_queued = True
                self.model_large.callback_queued = True
            # mutator_large never gets optimized, so it does not need to be parallelized

    def _warmup(self, phase, epoch):
        assert phase in [PHASE_SMALL, PHASE_LARGE]
        if phase == PHASE_SMALL:
            model, optimizer = self.model_small, self.optimizer_small
        elif phase == PHASE_LARGE:
            model, optimizer = self.model_large, self.optimizer_large
        model.train()
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):

            optimizer.zero_grad()
            totall_l = 0
            totall_p = 0
            for fb in range(self.fake_batch):
                x, y = next(self.train_loader)
                x, y = x.cuda(), y.cuda()
                logits_main, _ = model(x)
                loss = self.criterion(logits_main, y)/self.fake_batch
                loss.backward()
                totall_l += loss
                prec1, _ = accuracy(logits_main, y, topk=(1, 1))
                prec1 = prec1 / self.fake_batch
                totall_p += prec1


            self._clip_grad_norm(model)
            optimizer.step()

            metrics = {"prec1": totall_p, "loss": totall_l}
            metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, phase, meters)

    def _clip_grad_norm(self, model):
        if isinstance(model, DistributedDataParallel):
            nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)

    def _reset_nan(self, parameters):
        with torch.no_grad():
            for param in parameters:
                for i, p in enumerate(param):
                    if p != p:  # equivalent to `isnan(p)`
                        param[i] = float("-inf")

    def _joint_train(self, epoch):
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            totall_lc = 0
            totall_lw = 0
            totall_li = 0
            totall_lr = 0

            loss_regular = self.mutator_small.reset_with_loss()
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                    (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            if loss_regular:
                loss_regular *= reg_decay

            samples_x = []
            samples_y = []
            criterion_l = []
            emsemble_logits_l = []

            def trn_l(totall_lc, totall_lw, totall_li, totall_lr):

                self.model_large.train()
                self.optimizer_large.zero_grad()

                for fb in range(self.fake_batch):
                    val_x, val_y = next(self.valid_loader)
                    val_x, val_y = val_x.cuda(), val_y.cuda()

                    logits_main, emsemble_logits_main = self.model_large(val_x)
                    cel = self.criterion(logits_main, val_y)
                    loss_weight = cel / (self.fake_batch)
                    loss_weight.backward(retain_graph=True)

                    criterion_l.append(cel.cpu())
                    emsemble_logits_l.append(emsemble_logits_main.cpu())

                    totall_lw += float(loss_weight)
                    samples_x.append(val_x.cpu())
                    samples_y.append(val_y.cpu())

                self._clip_grad_norm(self.model_large)
                self.optimizer_large.step()
                self.model_large.train(mode=False)

                return totall_lc, totall_lw, totall_li, totall_lr

            totall_lc, totall_lw, totall_li, totall_lr = trn_l(totall_lc, totall_lw, totall_li, totall_lr)
            def sleep(s):
                print("--" + str(s))
                time.sleep(2)
                print(torch.cuda.memory_summary())
                print("++" + str(s))

            def trn_s(totall_lc, totall_lw, totall_li, totall_lr):
                print("sts")
                self.model_small.cuda()
                self.model_small.train()
                self.optimizer_alpha.zero_grad()
                self.optimizer_small.zero_grad()
                i = 0
                ls = []
                els = []
                sleep(0)
                def sc():
                    reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) / (
                            (self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
                    loss_regular = self.mutator_small.reset_with_loss()
                    if loss_regular:
                        loss_regular *= reg_decay
                    loss_regular.backward()
                    loss_regular = loss_regular.cpu().detach()
                sc()
                sleep(0.5)
                for i in range(len(samples_x)):
                    val_x = samples_x[i]
                    val_x = val_x.cuda()
                    val_y = samples_y[i]
                    val_y = val_y.cuda()


                    logits_search, emsemble_logits_search = self.model_small(val_x)
                    cls = self.criterion(logits_search, val_y)

                    ls.append(cls.cpu())
                    els.append(emsemble_logits_search.cpu())
                    val_x.cpu().detach()
                    val_y.cpu().detach()

                sleep(1)
                for i in range(len(samples_x)):
                    criterion_logits_main = criterion_l[i].cuda()
                    cls = ls[i].cuda()
                    emsemble_logits_search = els[i].cuda()
                    loss_weight = cls / (self.fake_batch)
                    totall_lw += float(loss_weight)
                    loss_cls = (cls + criterion_logits_main) / self.loss_alpha / self.fake_batch
                    loss_cls.backward(retain_graph=True)
                    totall_lc += float(loss_cls)
                    criterion_logits_main.cpu().detach()

                sleep(2)
                for i in range(len(samples_x)):
                    emsemble_logits_main = emsemble_logits_l[i].cuda()
                    emsemble_logits_search = els[i].cuda()
                    sleep(3)
                    loss_interactive = self.interactive_loss(emsemble_logits_search, emsemble_logits_main) * (
                                self.loss_T ** 2) * self.loss_alpha / self.fake_batch
                    loss_interactive.backward(retain_graph=True)
                    sleep(5)
                    emsemble_logits_search.cpu()
                    totall_li += float(loss_interactive)
                    totall_lr += float(loss_regular)
                    emsemble_logits_search.cpu().detach()
                    emsemble_logits_main.cpu().detach()
                    sleep(6)
                    i = i + 1


                self.optimizer_alpha.step()
                self._clip_grad_norm(self.model_small)
                self.optimizer_small.step()
                self.model_small.train(mode=False)
                samples_x.clear()
                samples_y.clear()
                criterion_l.clear()
                emsemble_logits_l.clear()
                return totall_lc, totall_lw, totall_li, totall_lr

            totall_lc, totall_lw, totall_li, totall_lr = trn_s(totall_lc, totall_lw, totall_li, totall_lr)



            metrics = {"loss_cls": totall_lc, "loss_interactive": totall_li,
                       "loss_regular": totall_lr, "loss_weight": totall_lw}
            #metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint)  %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)

    def train(self):
        for epoch in range(self.epochs):
            if epoch < self.warmup_epochs:
                with torch.no_grad():  # otherwise grads will be retained on the architecture params
                    self.mutator_small.reset_with_loss()
                self._warmup(PHASE_SMALL, epoch)
            else:
                with torch.no_grad():
                    self.mutator_large.reset()
                self._warmup(PHASE_LARGE, epoch)
                self._joint_train(epoch)

            self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
                        os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))

    def export(self, file, genotype_file):
        if self.main_proc:
            mutator_export, genotypes = self.mutator_small.export(self.logger)
            with open(file, "w") as f:
                json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
            with open(genotype_file, "w") as f:
                f.write(str(genotypes))
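
A hypothetical convenience wrapper around the trainer above; the search and evaluation networks, loaders and samplers come from the CDARTS example code and are placeholders here, as is the export file naming.

import torch.nn as nn

def launch_cdarts_search(model_small, model_large, loaders, samplers, local_rank=0):
    # Placeholder wiring: criterion, loaders and samplers must match what CdartsTrainer expects
    # (see the docstring above); export writes the architecture and genotypes of the search net.
    trainer = CdartsTrainer(model_small, model_large,
                            criterion=nn.CrossEntropyLoss().cuda(),
                            loaders=loaders, samplers=samplers,
                            distributed=True, local_rank=local_rank)
    trainer.train()
    trainer.export("final_arch.json", "final_arch.genotypes")
    return trainer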
Code example #11
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--features_h5path",
        default="/coc/pskynet2/jlu347/multi-modal-bert/data/flick30k/flickr30k.h5",
    )

    # Required parameters
    parser.add_argument(
        "--val_file",
        default="data/flick30k/all_data_final_test_set0_2014.jsonline",
        type=str,
        help="The input train corpus.",
    )

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )

    parser.add_argument(
        "--pretrained_weight",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )

    parser.add_argument(
        "--output_dir",
        default="result",
        type=str,
        # required=True,
        help="The output directory where the model checkpoints will be written.",
    )

    parser.add_argument(
        "--config_file",
        default="config/bert_config.json",
        type=str,
        # required=True,
        help="The config file which specified the model details.",
    )
    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=30,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )

    parser.add_argument(
        "--train_batch_size",
        default=128,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=50,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )

    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--from_pretrained",
        action="store_true",
        help="Wheter the tensor is from pretrained.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Wheter to use the baseline model (single bert).",
    )

    parser.add_argument(
        "--zero_shot", action="store_true", help="Wheter directly evaluate."
    )

    args = parser.parse_args()

    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from multimodal_bert.bert import MultiModalBertForImageCaptionRetrieval
        from multimodal_bert.bert import BertForMultiModalPreTraining
    else:
        from multimodal_bert.multi_modal_bert import (
            MultiModalBertForImageCaptionRetrieval,
            BertConfig,
        )
        from multimodal_bert.multi_modal_bert import BertForMultiModalPreTraining

    print(args)
    if args.save_name != "":
        timeStamp = args.save_name
    else:
        timeStamp = strftime("%d-%b-%y-%X-%a", gmtime())
        timeStamp += "_{:0>6d}".format(random.randint(0, 10e6))

    savePath = os.path.join(args.output_dir, timeStamp)

    if not os.path.exists(savePath):
        os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    # save all the hidden parameters.
    with open(os.path.join(savePath, "command.txt"), "w") as f:
        print(args, file=f)  # Python 3.x
        print("\n", file=f)
        print(config, file=f)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps
            )
        )

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # train_examples = None
    num_train_optimization_steps = None

    print("Loading Train Dataset", args.val_file)

    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case
    )
    image_features_reader = ImageFeaturesH5Reader(args.features_h5path, True)
    eval_dset = COCORetreivalDatasetVal(args.val_file, image_features_reader, tokenizer)

    config.fast_mode = True
    if args.from_pretrained:
        if args.zero_shot:
            model = BertForMultiModalPreTraining.from_pretrained(
                args.pretrained_weight, config
            )
        else:
            model = MultiModalBertForImageCaptionRetrieval.from_pretrained(
                args.pretrained_weight, config, dropout_prob=0.1
            )
    else:
        if args.zero_shot:
            model = BertForMultiModalPreTraining.from_pretrained(
                args.bert_model, config
            )
        else:
            model = MultiModalBertForImageCaptionRetrieval.from_pretrained(
                args.bert_model, config, dropout_prob=0.1
            )

    if args.fp16:
        model.half()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    model.cuda()
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(eval_dset))
    logger.info("  Batch size = %d", args.train_batch_size)

    eval_dataloader = DataLoader(
        eval_dset,
        shuffle=False,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=False,
    )

    startIterID = 0
    global_step = 0
    masked_loss_v_tmp = 0
    masked_loss_t_tmp = 0
    next_sentence_loss_tmp = 0
    loss_tmp = 0

    r1, r5, r10, medr, meanr = evaluate(args, model, eval_dataloader)
    print("finish evaluation, save result to %s")

    val_name = args.val_file.split("/")[-1]
    with open(os.path.join(savePath, val_name + "_result.txt"), "w") as f:
        print(
            "r1:%.3f, r5:%.3f, r10:%.3f, mder:%.3f, meanr:%.3f"
            % (r1, r5, r10, medr, meanr),
            file=f,
        )
Code example #12
def main(opts):
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = opts.local_rank, torch.device(opts.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging
    task_name = f"{opts.task}-{opts.dataset}"
    logdir_full = f"{opts.logdir}/{task_name}/{opts.name}/"
    if rank == 0:
        logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=opts.visualize, step=opts.step)
    else:
        logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=False)

    logger.print(f"Device: {device}")

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # xxx Set up dataloader
    train_dst, val_dst, test_dst, n_classes = get_dataset(opts)
    # reset the seed, this revert changes in random seed
    random.seed(opts.random_seed)

    train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size,
                                   sampler=DistributedSampler(train_dst, num_replicas=world_size, rank=rank),
                                   num_workers=opts.num_workers, drop_last=True)
    val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size if opts.crop_val else 1,
                                 sampler=DistributedSampler(val_dst, num_replicas=world_size, rank=rank),
                                 num_workers=opts.num_workers)
    logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, Val set: {len(val_dst)},"
                f" Test set: {len(test_dst)}, n_classes {n_classes}")
    logger.info(f"Total batch size is {opts.batch_size * world_size}")

    # xxx Set up model
    logger.info(f"Backbone: {opts.backbone}")

    step_checkpoint = None
    model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))
    logger.info(f"[!] Model made with{'out' if opts.no_pretrained else ''} pre-trained")

    if opts.step == 0:  # if step 0, we don't need to instance the model_old
        model_old = None
    else:  # instance model_old
        model_old = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step - 1))

    if opts.fix_bn:
        model.fix_bn()

    logger.debug(model)

    # xxx Set up optimizer
    params = []
    if not opts.freeze:
        params.append({"params": filter(lambda p: p.requires_grad, model.body.parameters()),
                       'weight_decay': opts.weight_decay})

    params.append({"params": filter(lambda p: p.requires_grad, model.head.parameters()),
                   'weight_decay': opts.weight_decay})

    params.append({"params": filter(lambda p: p.requires_grad, model.cls.parameters()),
                   'weight_decay': opts.weight_decay})

    optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=True)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, max_iters=opts.epochs * len(train_loader), power=opts.lr_power)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    else:
        raise NotImplementedError
    logger.debug("Optimizer:\n%s" % optimizer)

    if model_old is not None:
        [model, model_old], optimizer = amp.initialize([model.to(device), model_old.to(device)], optimizer,
                                                       opt_level=opts.opt_level)
        model_old = DistributedDataParallel(model_old)
    else:
        model, optimizer = amp.initialize(model.to(device), optimizer, opt_level=opts.opt_level)

    # Put the model on GPU
    model = DistributedDataParallel(model, delay_allreduce=True)

    # xxx Load old model from old weights if step > 0!
    if opts.step > 0:
        # get model path
        if opts.step_ckpt is not None:
            path = opts.step_ckpt
        else:
            path = f"checkpoints/step/{task_name}_{opts.name}_{opts.step - 1}.pth"

        # generate model from path
        if os.path.exists(path):
            step_checkpoint = torch.load(path, map_location="cpu")
            model.load_state_dict(step_checkpoint['model_state'], strict=False)  # False because of incr. classifiers
            if opts.init_balanced:
                # implement the balanced initialization (the new classifier gets the background weights and bias = bias_bkg - log(N+1))
                model.module.init_new_classifier(device)
            # Load state dict from the model state dict, that contains the old model parameters
            model_old.load_state_dict(step_checkpoint['model_state'], strict=True)  # Load also here old parameters
            logger.info(f"[!] Previous model loaded from {path}")
            # clean memory
            del step_checkpoint['model_state']
        elif opts.debug:
            logger.info(f"[!] WARNING: Unable to find of step {opts.step - 1}! Do you really want to do from scratch?")
        else:
            raise FileNotFoundError(path)
        # put the old model into distributed memory and freeze it
        for par in model_old.parameters():
            par.requires_grad = False
        model_old.eval()

    # xxx Set up Trainer
    trainer_state = None
    # if not first step, then instance trainer from step_checkpoint
    if opts.step > 0 and step_checkpoint is not None:
        if 'trainer_state' in step_checkpoint:
            trainer_state = step_checkpoint['trainer_state']

    # instance trainer (model must have already the previous step weights)
    trainer = Trainer(model, model_old, device=device, opts=opts, trainer_state=trainer_state,
                      classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))

    # xxx Handle checkpoint for current model (model old will always be as previous step or None)
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"], strict=True)
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint['best_score']
        logger.info("[!] Model restored from %s" % opts.ckpt)
        # if we want to resume training, resume trainer from checkpoint
        if 'trainer_state' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state'])
        del checkpoint
    else:
        if opts.step == 0:
            logger.info("[!] Train from scratch")

    # xxx Train procedure
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(val_loader), opts.sample_num, replace=False)  # sample idxs for visualization
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # de-normalization for original images

    TRAIN = not opts.test
    val_metrics = StreamSegMetrics(n_classes)
    results = {}

    # check if random is equal here.
    logger.print(torch.randint(0,100, (1,1)))
    # train/val here
    while cur_epoch < opts.epochs and TRAIN:
        # =====  Train  =====
        model.train()

        epoch_loss = trainer.train(cur_epoch=cur_epoch, optim=optimizer,
                                   train_loader=train_loader, scheduler=scheduler, logger=logger)

        logger.info(f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0]+epoch_loss[1]},"
                    f" Class Loss={epoch_loss[0]}, Reg Loss={epoch_loss[1]}")

        # =====  Log metrics on Tensorboard =====
        logger.add_scalar("E-Loss", epoch_loss[0]+epoch_loss[1], cur_epoch)
        logger.add_scalar("E-Loss-reg", epoch_loss[1], cur_epoch)
        logger.add_scalar("E-Loss-cls", epoch_loss[0], cur_epoch)

        # =====  Validation  =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            logger.info("validate on val set...")
            model.eval()
            val_loss, val_score, ret_samples = trainer.validate(loader=val_loader, metrics=val_metrics,
                                                                ret_samples_ids=sample_ids, logger=logger)

            logger.print("Done validation")
            logger.info(f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss[0]+val_loss[1]},"
                        f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}")

            logger.info(val_metrics.to_str(val_score))

            # =====  Save Best Model  =====
            if rank == 0:  # save best model at the last iteration
                score = val_score['Mean IoU']
                # best model to build incremental steps
                save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth",
                          model, trainer, optimizer, scheduler, cur_epoch, score)
                logger.info("[!] Checkpoint saved.")

            # =====  Log metrics on Tensorboard =====
            # visualize validation score and samples
            logger.add_scalar("V-Loss", val_loss[0]+val_loss[1], cur_epoch)
            logger.add_scalar("V-Loss-reg", val_loss[1], cur_epoch)
            logger.add_scalar("V-Loss-cls", val_loss[0], cur_epoch)
            logger.add_scalar("Val_Overall_Acc", val_score['Overall Acc'], cur_epoch)
            logger.add_scalar("Val_MeanIoU", val_score['Mean IoU'], cur_epoch)
            logger.add_table("Val_Class_IoU", val_score['Class IoU'], cur_epoch)
            logger.add_table("Val_Acc_IoU", val_score['Class Acc'], cur_epoch)
            # logger.add_figure("Val_Confusion_Matrix", val_score['Confusion Matrix'], cur_epoch)

            # keep the metric to print them at the end of training
            results["V-IoU"] = val_score['Class IoU']
            results["V-Acc"] = val_score['Class Acc']

            for k, (img, target, lbl) in enumerate(ret_samples):
                img = (denorm(img) * 255).astype(np.uint8)
                target = label2color(target).transpose(2, 0, 1).astype(np.uint8)
                lbl = label2color(lbl).transpose(2, 0, 1).astype(np.uint8)

                concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                logger.add_image(f'Sample_{k}', concat_img, cur_epoch)

        cur_epoch += 1

    # =====  Save Best Model at the end of training =====
    if rank == 0 and TRAIN:  # save best model at the last iteration
        # best model to build incremental steps
        save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth",
                  model, trainer, optimizer, scheduler, cur_epoch, best_score)
        logger.info("[!] Checkpoint saved.")

    torch.distributed.barrier()

    # xxx From here starts the test code
    logger.info("*** Test the model on all seen classes...")
    # make data loader
    test_loader = data.DataLoader(test_dst, batch_size=opts.batch_size if opts.crop_val else 1,
                                  sampler=DistributedSampler(test_dst, num_replicas=world_size, rank=rank),
                                  num_workers=opts.num_workers)

    # load best model
    if TRAIN:
        model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))
        # Put the model on GPU
        model = DistributedDataParallel(model.cuda(device))
        ckpt = f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth"
        checkpoint = torch.load(ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"])
        logger.info(f"*** Model restored from {ckpt}")
        del checkpoint
        trainer = Trainer(model, None, device=device, opts=opts)

    model.eval()

    val_loss, val_score, _ = trainer.validate(loader=test_loader, metrics=val_metrics, logger=logger)
    logger.print("Done test")
    logger.info(f"*** End of Test, Total Loss={val_loss[0]+val_loss[1]},"
                f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}")
    logger.info(val_metrics.to_str(val_score))
    logger.add_table("Test_Class_IoU", val_score['Class IoU'])
    logger.add_table("Test_Class_Acc", val_score['Class Acc'])
    logger.add_figure("Test_Confusion_Matrix", val_score['Confusion Matrix'])
    results["T-IoU"] = val_score['Class IoU']
    results["T-Acc"] = val_score['Class Acc']
    logger.add_results(results)

    logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'], opts.step)
    logger.add_scalar("T_MeanIoU", val_score['Mean IoU'], opts.step)
    logger.add_scalar("T_MeanAcc", val_score['Mean Acc'], opts.step)

    logger.close()
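
save_ckpt is not part of this excerpt; the sketch below is a hypothetical implementation that simply mirrors the keys read back by the resume and test paths above (model_state, optimizer_state, scheduler_state, trainer_state, epoch, best_score). It assumes the Trainer exposes state_dict(), as suggested by trainer.load_state_dict(...) in the resume branch.

import torch

def save_ckpt(path, model, trainer, optimizer, scheduler, epoch, best_score):
    # Hypothetical checkpoint writer consistent with the loading code above.
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "trainer_state": trainer.state_dict(),
        "best_score": best_score,
    }
    torch.save(state, path)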
Code example #13
class NetworkFactory(object):
    def __init__(self, system_config, model, distributed=False, gpu=None):
        super(NetworkFactory, self).__init__()

        self.system_config = system_config

        self.gpu = gpu
        self.model = DummyModule(model)
        self.loss = model.loss
        self.network = Network(self.model, self.loss)

        if distributed:
            from apex.parallel import DistributedDataParallel, convert_syncbn_model
            torch.cuda.set_device(gpu)
            self.network = self.network.cuda(gpu)
            self.network = convert_syncbn_model(self.network)
            self.network = DistributedDataParallel(self.network)
        else:
            # self.network = DataParallel(self.network, chunk_sizes=system_config.chunk_sizes)
            pass

        total_params = 0
        for params in self.model.parameters():
            num_params = 1
            for x in params.size():
                num_params *= x
            total_params += num_params
        print("total parameters: {}".format(total_params))

        if system_config.opt_algo == "adam":
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()))
        elif system_config.opt_algo == "sgd":
            self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                    self.model.parameters()),
                                             lr=system_config.learning_rate,
                                             momentum=0.9,
                                             weight_decay=0.0001)
        else:
            raise ValueError("unknown optimizer")

    def cuda(self):
        self.model.cuda()

    def train_mode(self):
        self.network.train()

    def eval_mode(self):
        self.network.eval()

    def _t_cuda(self, xs):
        if type(xs) is list:
            return [x.cuda(self.gpu, non_blocking=True) for x in xs]
        return xs.cuda(self.gpu, non_blocking=True)

    def train(self, xs, ys, **kwargs):
        xs = [self._t_cuda(x) for x in xs]
        ys = [self._t_cuda(y) for y in ys]

        self.optimizer.zero_grad()
        loss = self.network(xs, ys)
        loss = loss.mean()
        loss.backward()
        self.optimizer.step()

        return loss

    def validate(self, xs, ys, **kwargs):
        with torch.no_grad():
            xs = [self._t_cuda(x) for x in xs]
            ys = [self._t_cuda(y) for y in ys]

            loss = self.network(xs, ys)
            loss = loss.mean()
            return loss

    def test(self, xs, **kwargs):
        with torch.no_grad():
            xs = [self._t_cuda(x) for x in xs]
            return self.model(*xs, **kwargs)

    def set_lr(self, lr):
        print("setting learning rate to: {}".format(lr))
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def load_pretrained_params(self, pretrained_model):
        print("loading from {}".format(pretrained_model))
        with open(pretrained_model, "rb") as f:
            params = torch.load(f)
            self.model.load_state_dict(params)

    def load_params(self, iteration):
        cache_file = self.system_config.snapshot_file.format(iteration)
        print("loading model from {}".format(cache_file))
        with open(cache_file, "rb") as f:
            params = torch.load(f)
            self.model.load_state_dict(params)

    def save_params(self, iteration):
        cache_file = self.system_config.snapshot_file.format(iteration)
        print("saving model to {}".format(cache_file))
        with open(cache_file, "wb") as f:
            params = self.model.state_dict()
            torch.save(params, f)
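A hypothetical driver loop for the NetworkFactory above might look like the sketch below; training_batches, system_config.snapshot_freq and the model itself are assumed placeholders that do not appear in the snippet.

nnet = NetworkFactory(system_config, model, distributed=False)
nnet.cuda()
nnet.train_mode()
for iteration, (xs, ys) in enumerate(training_batches, 1):
    loss = nnet.train(xs, ys)          # forward + backward + optimizer step
    if iteration % system_config.snapshot_freq == 0:
        nnet.save_params(iteration)    # writes system_config.snapshot_file.format(iteration)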
Code Example #14
class TRans2InfoMax(Trans2Net):
    def __init__(self, cfg, writer=None):
        super(TRans2InfoMax, self).__init__(cfg, writer)

    def _define_networks(self):

        self.net = networks.Source_Model(self.cfg, device=self.device)
        self.cross_encoder = networks.Cross_Model(self.cfg, device=self.device)
        self.d_distribute = networks.GANDiscriminator(self.cfg,
                                                      device=self.device)
        self.model_names = ['net', 'cross_encoder', 'd_distribute']

        networks.print_network(self.net)
        networks.print_network(self.cross_encoder)
        networks.print_network(self.d_distribute)

        if 'PIX2PIX' in self.cfg.LOSS_TYPES:
            criterion_pix2pix = torch.nn.L1Loss()
            self.cross_encoder.set_pix2pix_criterion(criterion_pix2pix)

    def set_device(self):

        if not self.cfg.MULTIPROCESSING_DISTRIBUTED:
            self.net = nn.DataParallel(self.net).to(self.device)
            self.cross_encoder = nn.DataParallel(self.cross_encoder).to(
                self.device)
            self.d_distribute = nn.DataParallel(self.d_distribute).to(
                self.device)

    def set_optimizer(self, cfg):

        self.optimizers = []

        # if self.cfg.RESUME:
        #     self.params_list = []
        #     self.modules_ft = [self.net.layer0, self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
        #     self.modules_sc = [self.net.evaluator]
        #
        #
        #     for module in self.modules_ft:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR))
        #     for module in self.modules_sc:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR * 10))
        #     self.optimizer_g = torch.optim.Adam(self.params_list, lr=cfg.LR, betas=(0.5, 0.999))
        # else:
        self.optimizer_g = torch.optim.Adam(self.net.parameters(),
                                            lr=cfg.LR,
                                            betas=(0.5, 0.999))
        self.optimizer_c = torch.optim.Adam(self.cross_encoder.parameters(),
                                            lr=cfg.LR,
                                            betas=(0.5, 0.999))
        self.optimizer_d = torch.optim.SGD(self.d_distribute.parameters(),
                                           lr=cfg.LR,
                                           momentum=0.9,
                                           weight_decay=0.0005)

        if cfg.MULTIPROCESSING_DISTRIBUTED:
            if cfg.USE_APEX:
                self.net, self.optimizer_g = apex.amp.initialize(
                    self.net.cuda(), self.optimizer_g, opt_level=cfg.opt_level)
                self.cross_encoder, self.optimizer_c = apex.amp.initialize(
                    self.cross_encoder.cuda(),
                    self.optimizer_c,
                    opt_level=cfg.opt_level)
                self.d_distribute, self.optimizer_d = apex.amp.initialize(
                    self.d_distribute.cuda(),
                    self.optimizer_d,
                    opt_level=cfg.opt_level)
                self.net = DDP(self.net)
                self.cross_encoder = DDP(self.cross_encoder)
                self.d_distribute = DDP(self.d_distribute)
            else:
                self.net = torch.nn.parallel.DistributedDataParallel(
                    self.net.cuda(), device_ids=[cfg.gpu])
                self.cross_encoder = torch.nn.parallel.DistributedDataParallel(
                    self.cross_encoder.cuda(), device_ids=[cfg.gpu])
                self.d_distribute = torch.nn.parallel.DistributedDataParallel(
                    self.d_distribute.cuda(), device_ids=[cfg.gpu])

        self.optimizers.append(self.optimizer_d)
        self.optimizers.append(self.optimizer_g)
        self.optimizers.append(self.optimizer_c)

    # def get_patch(self, img):
    #
    #     # Input of the function is a tensor [B, C, H, W]
    #     # Output of the functions is a tensor [B * 49, C, 64, 64]
    #
    #     patch_batch = None
    #     all_patches_list = []
    #
    #     for y_patch in range(3):
    #         for x_patch in range(3):
    #             y1 = y_patch * 64
    #             y2 = y1 + 128
    #
    #             x1 = x_patch * 64
    #             x2 = x1 + 128
    #
    #             img_patches = img[:, :, y1:y2, x1:x2]  # Batch(img_idx in batch), channels xrange, yrange
    #             img_patches = img_patches.unsqueeze(dim=1)
    #             all_patches_list.append(img_patches)
    #
    #             # print(patch_batch.shape)
    #     all_patches_tensor = torch.cat(all_patches_list, dim=1)
    #
    #     patches_per_image = []
    #     for b in range(all_patches_tensor.shape[0]):
    #         patches_per_image.append(all_patches_tensor[b])
    #
    #     patch_batch = torch.cat(patches_per_image, dim=0)
    #     return patch_batch

    # encoder-decoder branch
    def _forward(self, class_only=False):

        # if self.phase == 'train':
        #     self.source_modal = self.get_patch(self.source_modal)
        #     self.target_modal = self.get_patch(self.target_modal)
        #
        #     self.source_modal.view(self.batch_size, 3, 3, -1)

        if self.label is not None:
            label = self.label
        else:
            label = None
        self.result_g = self.net(self.source_modal,
                                 target=self.target_modal,
                                 label=label,
                                 class_only=class_only)
        if self.phase == 'train' and not self.cfg.NO_TRANS:
            self.result_c = self.cross_encoder(self.result_g['gen_cross'],
                                               target=self.target_modal)

    def _optimize(self, iter):

        self._forward()

        if 'GAN' in self.cfg.LOSS_TYPES:

            self.set_requires_grad([self.cross_encoder, self.net], False)
            self.set_requires_grad(self.d_distribute, True)
            fake_d = torch.cat(
                (self.result_c['feat_gen'], self.result_c['feat_target']), 1)
            real_d = torch.cat(
                (self.result_c['feat_target'], self.result_c['feat_target']),
                1)
            # fake_d = self.result_c['feat_gen']
            # real_d = self.result_c['feat_target']

            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                loss_d_fake = self.d_distribute(fake_d.detach(), False)
                loss_d_true = self.d_distribute(real_d.detach(), True)
            else:
                loss_d_fake = self.d_distribute(fake_d.detach(), False).mean()
                loss_d_true = self.d_distribute(real_d.detach(), True).mean()

            loss_d = (loss_d_fake + loss_d_true) * 0.5
            self.loss_meters['TRAIN_GAN_D_LOSS'].update(
                loss_d.item(), self.batch_size)

            self.optimizer_d.zero_grad()
            if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED:
                with apex.amp.scale_loss(loss_d,
                                         self.optimizer_d) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_d.backward()
            self.optimizer_d.step()

        # G
        loss_g = self._construct_loss(iter)
        self.set_requires_grad([self.cross_encoder, self.net], True)
        if self.d_distribute is not None:
            self.set_requires_grad(self.d_distribute, False)

        self.optimizer_c.zero_grad()
        self.optimizer_g.zero_grad()
        if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED:
            with apex.amp.scale_loss(
                    loss_g,
                [self.optimizer_c, self.optimizer_g]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_g.backward()
        self.optimizer_c.step()
        self.optimizer_g.step()

    def _construct_loss(self, iter=None):

        loss_total = torch.zeros(1).to(self.device)
        # decay_coef = 1
        decay_coef = (iter / self.cfg.NITER_TOTAL)  # small to big

        if 'PIX2PIX' in self.cfg.LOSS_TYPES:

            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                local_loss = self.result_c[
                    'pix2pix_loss'] * self.cfg.ALPHA_LOCAL
                loss_total += local_loss
            else:
                local_loss = self.result_c['pix2pix_loss'].mean() * self.cfg.ALPHA_LOCAL
                loss_total += local_loss

            self.loss_meters['TRAIN_PIX2PIX_LOSS'].update(
                local_loss.item(), self.batch_size)

        if 'PRIOR' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                prior_loss = self.result_g['prior_loss'] * self.cfg.ALPHA_PRIOR
                loss_total += prior_loss
            else:
                prior_loss = self.result_g['prior_loss'].mean() * self.cfg.ALPHA_PRIOR
                loss_total += prior_loss

            self.loss_meters['TRAIN_PRIOR_LOSS'].update(
                prior_loss.item(), self.batch_size)

        if 'CROSS' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:

                cross_loss = self.result_c['cross_loss'] * self.cfg.ALPHA_CROSS

                # cross_loss_self = self.result_c['cross_loss_self'] * self.cfg.ALPHA_CROSS * 0.2
            else:
                cross_loss = self.result_c['cross_loss'].mean(
                ) * self.cfg.ALPHA_CROSS
                # cross_loss_self = self.result_c['cross_loss_self'].mean() * self.cfg.ALPHA_CROSS * decay_coef

            loss_total += cross_loss
            # loss_total += cross_loss_self
            # loss_total += cross_loss

            self.loss_meters['TRAIN_CROSS_LOSS'].update(
                cross_loss.item(), self.batch_size)
            # self.loss_meters['TRAIN_CROSS_LOSS_SELF'].update(cross_loss_self.item(), self.batch_size)

        if 'HOMO' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                homo_loss = self.result_g['homo_loss'] * self.cfg.ALPHA_CROSS
                loss_total += homo_loss
            else:
                homo_loss = self.result_g['homo_loss'].mean() * self.cfg.ALPHA_CROSS
                loss_total += homo_loss

            self.loss_meters['TRAIN_HOMO_LOSS'].update(homo_loss.item(),
                                                       self.batch_size)

        if 'CLS' in self.cfg.LOSS_TYPES:

            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                cls_loss = self.result_g['cls_loss']
            else:
                cls_loss = self.result_g['cls_loss'].mean()
            self.loss_meters['TRAIN_CLS_LOSS'].update(cls_loss.item(),
                                                      self.batch_size)

            loss_total += cls_loss

        if 'GAN' in self.cfg.LOSS_TYPES:

            # real_g = self.result_c['feat_gen']
            real_g = torch.cat(
                (self.result_c['feat_gen'], self.result_c['feat_target']), 1)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                loss_gan_g = self.d_distribute(real_g,
                                               True) * self.cfg.ALPHA_GAN
            else:
                loss_gan_g = self.d_distribute(
                    real_g, True).mean() * self.cfg.ALPHA_GAN
            self.loss_meters['TRAIN_GAN_G_LOSS'].update(
                loss_gan_g.item(), self.batch_size)

            loss_total += loss_gan_g

        return loss_total

    def set_log_data(self, cfg):

        super().set_log_data(cfg)
        self.log_keys = [
            'TRAIN_CROSS_LOSS', 'TRAIN_CROSS_LOSS_SELF', 'TRAIN_HOMO_LOSS',
            'TRAIN_LOCAL_LOSS', 'TRAIN_PRIOR_LOSS', 'INTERSECTION_MLP',
            'LABEL_MLP', 'INTERSECTION_LIN', 'LABEL_LIN', 'VAL_CLS_ACC_MLP',
            'VAL_CLS_MEAN_ACC_MLP', 'TRAIN_GAN_D_LOSS', 'TRAIN_GAN_G_LOSS',
            'TRAIN_CONTRASTIVE_LOSS'
        ]
        for item in self.log_keys:
            self.loss_meters[item] = AverageMeter()

    def evaluate(self):

        # evaluate model on test_loader
        self.net.eval()
        self.phase = 'test'

        intersection_meter_mlp = self.loss_meters['INTERSECTION_MLP']
        target_meter_mlp = self.loss_meters['LABEL_MLP']

        for i, data in enumerate(self.val_loader):
            self.set_input(data)
            with torch.no_grad():
                self._forward(class_only=True)

                pred = self.result_g['pred'].data.max(1)[1]
                # lgt_glb_mlp = lgt_glb_mlp
                # lgt_glb_lin = lgt_glb_lin.data.max(1)[1]
                # [lgt_glb_mlp, lgt_glb_lin] = self.result_g['pred']
                # lgt_glb_mlp = lgt_glb_mlp.data.max(1)[1]
                # lgt_glb_lin = lgt_glb_lin.data.max(1)[1]

            intersection_mlp, union_mlp, label_mlp = util.intersectionAndUnionGPU(
                pred, self.label, self.cfg.NUM_CLASSES)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                dist.all_reduce(intersection_mlp), dist.all_reduce(
                    union_mlp), dist.all_reduce(label_mlp)

            intersection_mlp, union_mlp, label_mlp = intersection_mlp.cpu(
            ).numpy(), union_mlp.cpu().numpy(), label_mlp.cpu().numpy()

            intersection_meter_mlp.update(intersection_mlp, self.batch_size)
            target_meter_mlp.update(label_mlp, self.batch_size)

        # Mean ACC
        allAcc_mlp = sum(
            intersection_meter_mlp.sum) / (sum(target_meter_mlp.sum) + 1e-10)
        accuracy_class_mlp = intersection_meter_mlp.sum / (
            target_meter_mlp.sum + 1e-10)
        mAcc_mlp = np.mean(accuracy_class_mlp)
        self.loss_meters['VAL_CLS_ACC_MLP'].update(allAcc_mlp)
        self.loss_meters['VAL_CLS_MEAN_ACC_MLP'].update(mAcc_mlp)

    def write_loss(self, phase, global_step):

        task = self.cfg.TASK_TYPE
        self.writer.add_image(task + '/rgb',
                              torchvision.utils.make_grid(
                                  self.source_modal[:6].clone().cpu().data,
                                  3,
                                  normalize=True),
                              global_step=global_step)
        if phase == 'train':

            if not self.cfg.NO_TRANS:

                for k, v in self.result_g.items():
                    if 'gen' in k:
                        # if isinstance(self.result_g[k], list):
                        #     for i, (gen, _depth) in enumerate(zip(self.result_g['gen'], self.target_modal)):
                        #         self.writer.add_image(task + '/' + k + str(self.cfg.FINE_SIZE[0] / pow(2, i)),
                        #                          torchvision.utils.make_grid(gen[:6].clone().cpu().data, 3,
                        #                                                      normalize=True),
                        #                          global_step=global_step)
                        #         self.writer.add_image(task + '/target' + str(self.cfg.FINE_SIZE[0] / pow(2, i)),
                        #                          torchvision.utils.make_grid(_depth[:6].clone().cpu().data, 3,
                        #                                                      normalize=True),
                        #                          global_step=global_step)
                        # else:
                        self.writer.add_image(
                            task + '/' + k,
                            torchvision.utils.make_grid(
                                self.result_g[k][:6].clone().cpu().data,
                                3,
                                normalize=True),
                            global_step=global_step)

                self.writer.add_image(
                    task + '/target',
                    torchvision.utils.make_grid(
                        self.target_modal[:6].clone().cpu().data,
                        3,
                        normalize=True),
                    global_step=global_step)
                # self.writer.add_image(task + '/target_neg',
                #                       torchvision.utils.make_grid(self.target_modal_neg[:6].clone().cpu().data, 3,
                #
                #                                                   normalize=True), global_step=global_step)

            self.writer.add_scalar(task + '/LR',
                                   self.optimizer_g.param_groups[0]['lr'],
                                   global_step=global_step)

            for k, v in self.loss_meters.items():
                if 'LOSS' in k and v.avg > 0:
                    self.writer.add_scalar(task + '/' + k,
                                           v.avg,
                                           global_step=global_step)

        elif phase == 'test':

            for k, v in self.loss_meters.items():
                if ('MEAN' in k or 'ACC' in k) and v.val > 0:
                    self.writer.add_scalar(task + '/' + k,
                                           v.val * 100.0,
                                           global_step=global_step)
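The _optimize method above drives apex amp with several optimizers at once. The snippet calls apex.amp.initialize once per model/optimizer pair; the minimal sketch below isolates the same pattern using the list form of initialize (which apex also supports) and scales a shared loss against a list of optimizers. The two Linear modules are placeholders and a CUDA device is assumed.

import apex
import torch

net = torch.nn.Linear(8, 8).cuda()
dec = torch.nn.Linear(8, 8).cuda()
opt_g = torch.optim.Adam(net.parameters(), lr=1e-4)
opt_c = torch.optim.Adam(dec.parameters(), lr=1e-4)

# One initialize call may take lists of models and optimizers.
(net, dec), (opt_g, opt_c) = apex.amp.initialize([net, dec], [opt_g, opt_c], opt_level="O1")

x = torch.randn(4, 8).cuda()
loss = dec(net(x)).pow(2).mean()

opt_g.zero_grad()
opt_c.zero_grad()
# A loss shared by two sub-networks is scaled against both optimizers.
with apex.amp.scale_loss(loss, [opt_c, opt_g]) as scaled_loss:
    scaled_loss.backward()
opt_c.step()
opt_g.step()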
Code Example #15
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    # Data files for VQA task.
    parser.add_argument("--features_h5path", default="data/coco/test2015.h5")
    parser.add_argument(
        "--train_file",
        default="data/VQA/training",
        type=str,
        # required=True,
        help="The input train corpus.",
    )
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )

    parser.add_argument(
        "--pretrained_weight",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )

    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        # required=True,
        help=
        "The output directory where the model checkpoints will be written.",
    )

    parser.add_argument(
        "--config_file",
        default="config/bert_config.json",
        type=str,
        # required=True,
        help="The config file which specified the model details.",
    )
    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=30,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )

    parser.add_argument("--use_location",
                        action="store_true",
                        help="whether use location.")
    parser.add_argument(
        "--train_batch_size",
        default=128,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=30,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )

    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=20,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--from_pretrained",
        action="store_true",
        help="Wheter the tensor is from pretrained.",
    )
    parser.add_argument("--save_name",
                        default="",
                        type=str,
                        help="save name for training.")
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Wheter to use the baseline model (single bert).",
    )
    parser.add_argument("--split",
                        default="test",
                        type=str,
                        help="train or trainval.")

    parser.add_argument(
        "--use_chunk",
        default=0,
        type=float,
        help="whether use chunck for parallel training.",
    )
    args = parser.parse_args()

    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from multimodal_bert.bert import MultiModalBertForVQA
    else:
        from multimodal_bert.multi_modal_bert import MultiModalBertForVQA, BertConfig

    print(args)
    if args.save_name != "":
        timeStamp = args.save_name
    else:
        timeStamp = strftime("%d-%b-%y-%X-%a", gmtime())
        timeStamp += "_{:0>6d}".format(random.randint(0, 10e6))

    savePath = os.path.join(args.output_dir, timeStamp)

    if not os.path.exists(savePath):
        os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    # save all the hidden parameters.
    with open(os.path.join(savePath, "command.txt"), "w") as f:
        print(args, file=f)  # Python 3.x
        print("\n", file=f)
        print(config, file=f)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # train_examples = None
    num_train_optimization_steps = None

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    image_features_reader = ImageFeaturesH5Reader(args.features_h5path, True)

    if args.split == "minval":
        eval_dset = VQAClassificationDataset("minval",
                                             image_features_reader,
                                             tokenizer,
                                             dataroot="data/VQA")
    elif args.split == "test":
        eval_dset = VQAClassificationDataset("test",
                                             image_features_reader,
                                             tokenizer,
                                             dataroot="data/VQA")
    elif args.split == "val":
        eval_dset = VQAClassificationDataset("val",
                                             image_features_reader,
                                             tokenizer,
                                             dataroot="data/VQA")
    elif args.split == "test-dev":
        eval_dset = VQAClassificationDataset("test-dev",
                                             image_features_reader,
                                             tokenizer,
                                             dataroot="data/VQA")

    num_labels = eval_dset.num_ans_candidates
    if args.from_pretrained:
        model = MultiModalBertForVQA.from_pretrained(args.pretrained_weight,
                                                     config,
                                                     num_labels=num_labels)
    else:
        model = MultiModalBertForVQA.from_pretrained(args.bert_model,
                                                     config,
                                                     num_labels=num_labels)

    if args.fp16:
        model.half()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = DataParallel(model, use_chuncks=args.use_chunk)

    model.cuda()

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_dset))
    logger.info("  Batch size = %d", args.train_batch_size)

    eval_dataloader = DataLoader(
        eval_dset,
        shuffle=False,
        batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    startIterID = 0
    global_step = 0
    masked_loss_v_tmp = 0
    masked_loss_t_tmp = 0
    next_sentence_loss_tmp = 0
    loss_tmp = 0
    start_t = timer()

    model.train(False)
    eval_score, bound = evaluate(args, model, eval_dataloader)
    logger.info("\teval score: %.2f (%.2f)" % (100 * eval_score, 100 * bound))
Code Example #16
def main_worker(gpu, ngpus_per_node, args):
    
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(gpu)

    #################################################data
    PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/')
    WEIGHTS_PATH = Path('./weights/train_1024/')
    
    WEIGHTS_PATH.mkdir(exist_ok=True)
    batch_size = 60
    normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std)
    
    train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()])
    train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize]))	
    val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize]))	
    test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize]))
    #######################################
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dset)
    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=train_sampler)
    #######################################
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dset)
    val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, num_workers=2,pin_memory=True,sampler=val_sampler)
    #######################################
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset)
    test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=test_sampler)
    #######################################
    
    print("Train: %d" %len(train_loader.dataset.imgs))
    print("Val: %d" %len(val_loader.dataset.imgs))
    print("Test: %d" %len(test_loader.dataset.imgs))
    print("Classes: %d" % len(train_loader.dataset.classes))
    
    inputs, targets = next(iter(train_loader))
    print("Inputs: ", inputs.size())
    print("Targets: ", targets.size())
    # utils.imgs.view_image(inputs[0])
    # utils.imgs.view_annotated(targets[0])
    EE = 4
    device = 'cuda'
    EE_size = 256
    LR = 1e-3
    model = UNet4(n_channels=3, n_classes=4)
    # gpu = args.local_rank
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    # model = model.to(device)
    # model = torch.nn.DataParallel(model).cuda()
    ########################################
    optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4)
    model, optimizer = amp.initialize(model,optimizer)
    model = DistributedDataParallel(model)
    ###################################################
    
    # print('EE, model', EE, model)
    pred_dir = './train_PG_pred/'
    FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test'

    #############################

    EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024#
    
    epoch_num = 10000
    best_loss = 1.
    best_dice = 0.
    LR_DECAY = 0.95
    DECAY_EVERY_N_EPOCHS = 10
    criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda(gpu)
    cudnn.benchmark = True
    for epoch in range(1, epoch_num):
        start_time = datetime.datetime.now()
        print('start_time',start_time)
        model = model.cuda()
        
        ##################################################
        ### Train ###
        trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size)
        print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE))    
        ## Test ###
        val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size)   
        print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE))
        ### Checkpoint ###    
        DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size)
        print('-----------test_dice',DICE1)
        if best_dice < DICE1:
            # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE)
            best_dice = DICE1
        ### Adjust Lr ###
        adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS)
        end_time =  datetime.datetime.now()
        print('end_time', end_time)
        print('time', (end_time - start_time).seconds)
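The adjust_learning_rate helper called at the end of each epoch above is not shown in the snippet; a minimal implementation consistent with how it is called (multiplicative decay by LR_DECAY every DECAY_EVERY_N_EPOCHS epochs) might look like the following. This is an assumption about the helper, not the original code.

def adjust_learning_rate(lr, decay, optimizer, cur_epoch, every_n_epochs):
    # Assumed behaviour: step decay, new_lr = lr * decay ** (cur_epoch // every_n_epochs).
    new_lr = lr * (decay ** (cur_epoch // every_n_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr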
Code Example #17
class Trans2Net(BaseModel):
    def __init__(self, cfg, writer=None, batch_norm=nn.BatchNorm2d):
        super(Trans2Net, self).__init__(cfg)
        self.phase = cfg.PHASE
        self.trans = not cfg.NO_TRANS
        self.content_model = None
        self.writer = writer
        self.batch_size_train = cfg.BATCH_SIZE_TRAIN
        self.batch_size_val = cfg.BATCH_SIZE_VAL
        self.batch_norm = batch_norm
        self._define_networks()
        self.params_list = []
        # self.set_criterion(cfg)

    def _define_networks(self):

        networks.batch_norm = self.batch_norm
        self.net = networks.define_netowrks(self.cfg, device=self.device)

        self.model_names = ['net']

        if 'GAN' in self.cfg.LOSS_TYPES:
            self.discriminator = networks.GANDiscriminator_Image(
                self.cfg, device=self.device)
            self.model_names.append('discriminator')

        # if 'PSP' in cfg.MODEL:
        #     self.modules_ft = [self.net.layer0, self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
        #     self.modules_sc = [self.net.ppm, self.net.cls, self.net.aux, self.net.score_aux1, self.net.score_aux2]
        #
        #     if self.trans:
        #         self.modules_ft.extend(
        #             [self.net.up0, self.net.up1, self.net.up2, self.net.up3,
        #              self.net.up4, self.net.up5, self.net.up_seg])
        #
        #     for module in self.modules_sc:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR * 5))
        #     for module in self.modules_ft:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR))

        # else:
        #
        #     self.modules_ft = [self.net.layer0, self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
        #     self.modules_sc = [self.net.score_head, self.net.score_aux1, self.net.score_aux2]
        #     if self.trans:
        #         self.modules_sc.extend(
        #             [self.net.up1, self.net.up2, self.net.up3, self.net.up4, self.net.up5, self.net.up_image])

        # for module in self.modules_sc:
        #     self.params_list.append(dict(params=module.parameters(), lr=cfg.LR * 5))
        # for module in self.modules_ft:
        #     self.params_list.append(dict(params=module.parameters(), lr=cfg.LR))

        if self.cfg.USE_FAKE_DATA or self.cfg.USE_COMPL_DATA:
            print('Use fake data: sample model is {0}'.format(
                self.cfg.SAMPLE_MODEL_PATH))
            print('fake ratio:', self.cfg.FAKE_DATA_RATE)
            cfg_sample = copy.deepcopy(self.cfg)
            cfg_sample.USE_FAKE_DATA = False
            cfg_sample.USE_COMPL_DATA = False
            cfg_sample.NO_TRANS = False
            cfg_sample.MODEL = 'trecg_compl'
            model = networks.define_netowrks(cfg_sample, device=self.device)
            checkpoint_path = os.path.join(self.cfg.CHECKPOINTS_DIR,
                                           self.cfg.SAMPLE_MODEL_PATH)
            self._load_checkpoint(model,
                                  checkpoint_path,
                                  key='net',
                                  keep_fc=False)

            # for mit 67
            # self.net = copy.deepcopy(model.compl_net)

            model.eval()
            if self.cfg.USE_COMPL_DATA:
                self.net.set_sample_model(model)
            else:
                self.sample_model = nn.DataParallel(model).to(self.device)

        networks.print_network(self.net)

        # print('Use fake data: sample model is {0}'.format(cfg.SAMPLE_MODEL_PATH))
        # print('fake ratio:', cfg.FAKE_DATA_RATE)
        # sample_model_path = cfg.SAMPLE_MODEL_PATH
        # cfg_sample = copy.deepcopy(cfg)
        # cfg_sample.USE_FAKE_DATA = False
        # model = networks.define_netowrks(cfg_sample, device=self.device)
        # self.load_checkpoint(net=model, checkpoint_path=sample_model_path)
        # model.eval()
        # self.sample_model = nn.DataParallel(model).to(self.device)

    def set_device(self):

        if not self.cfg.MULTIPROCESSING_DISTRIBUTED:
            self.net = nn.DataParallel(self.net).to(self.device)
            if 'GAN' in self.cfg.LOSS_TYPES:
                self.discriminator = nn.DataParallel(self.discriminator).to(
                    self.device)

    def _optimize(self, iter):

        self._forward()

        if 'GAN' in self.cfg.LOSS_TYPES:

            self.set_requires_grad(self.net, False)
            self.set_requires_grad(self.discriminator, True)
            fake_d = self.result['gen_img']
            real_d = self.target_modal
            # fake_d = self.result_c['feat_gen']
            # real_d = self.result_c['feat_target']

            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                loss_d_fake = self.discriminator(fake_d.detach(), False)
                loss_d_true = self.discriminator(real_d.detach(), True)
            else:
                loss_d_fake = self.discriminator(fake_d.detach(), False).mean()
                loss_d_true = self.discriminator(real_d.detach(), True).mean()

            loss_d = (loss_d_fake + loss_d_true) * 0.5
            self.loss_meters['TRAIN_GAN_D_LOSS'].update(
                loss_d.item(), self.batch_size)

            self.optimizer_d.zero_grad()
            if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED:
                with apex.amp.scale_loss(loss_d,
                                         self.optimizer_d) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_d.backward()
            self.optimizer_d.step()

        loss_g = self._construct_loss(iter)

        if 'GAN' in self.cfg.LOSS_TYPES and self.discriminator is not None:
            self.set_requires_grad(self.discriminator, False)
            self.set_requires_grad(self.net, True)

        self.optimizer.zero_grad()
        if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED:
            with apex.amp.scale_loss(loss_g, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_g.backward()
        self.optimizer.step()

    def set_criterion(self, cfg):

        if 'CLS' in self.cfg.LOSS_TYPES or self.cfg.EVALUATE:
            criterion_cls = util.CrossEntropyLoss(
                weight=cfg.CLASS_WEIGHTS_TRAIN,
                device=self.device,
                ignore_index=cfg.IGNORE_LABEL)
            self.net.set_cls_criterion(criterion_cls)

        if 'SEMANTIC' in self.cfg.LOSS_TYPES:
            criterion_content = torch.nn.L1Loss()
            content_model = networks.Content_Model(cfg, criterion_content).to(
                self.device)
            self.net.set_content_model(content_model)

        if 'PIX2PIX' in self.cfg.LOSS_TYPES:
            criterion_pix2pix = torch.nn.L1Loss()
            self.net.set_pix2pix_criterion(criterion_pix2pix)

    def set_input(self, data):

        self._source = data['image']
        self.source_modal = self._source.to(self.device)
        self.batch_size = self._source.size()[0]
        if 'label' in data.keys():
            self._label = data['label']
            self.label = torch.LongTensor(self._label).to(self.device)
        else:
            self.label = None

        if self.cfg.TARGET_MODAL:
            if self.cfg.MULTI_SCALE:
                self.target_modal = data[self.cfg.TARGET_MODAL][-1].to(
                    self.device)
            else:
                self.target_modal = data[self.cfg.TARGET_MODAL].to(self.device)
        else:
            self.target_modal = None

        # if self.trans or self.cfg.RESUME:
        #     if not self.cfg.MULTI_SCALE:
        #         self.target_modal = self.target_modal
        #     else:

        # if self.cfg.WHICH_DIRECTION == 'BtoA':
        #     self.source_modal, self.target_modal = self.target_modal, self.source_modal

    def train_parameters(self, cfg):

        assert self.cfg.LOSS_TYPES
        self.set_optimizer(cfg)
        self.set_log_data(cfg)
        self.set_schedulers(cfg)
        self.set_device()

        # self.net = nn.DataParallel(self.net).to(self.device)

        train_iters = 0
        best_result = 0

        if self.cfg.EVALUATE and self.cfg.SLIDE_WINDOWS:
            self.prediction_matrix = torch.zeros(
                self.batch_size_val, self.cfg.NUM_CLASSES,
                self.cfg.BASE_SIZE[0], self.cfg.BASE_SIZE[1]).to(self.device)
            self.count_crop_matrix = torch.zeros(
                self.batch_size_val, 1, self.cfg.BASE_SIZE[0],
                self.cfg.BASE_SIZE[1]).to(self.device)

        if cfg.INFERENCE:
            self.phase = 'test'
            start_time = time.time()
            print('Inferencing model...')
            self.evaluate()
            self.print_evaluate_results()
            save_dir = './images/'
            # np.savetxt(save_dir+'/target.txt',self.target_index_all)
            # np.savetxt(save_dir+'/pred.txt',self.pred_index_all)
            np.savetxt(save_dir + '/class_baseline.txt', self.accuracy_class)
            # self.target_index_all=np.loadtxt(save_dir+'/target.txt')
            # self.pred_index_all=np.loadtxt(save_dir+'/pred.txt')

            # from sklearn.metrics import confusion_matrix
            # cm=confusion_matrix(self.target_index_all,self.pred_index_all)
            util.plot_confusion_matrix(self.target_index_all,
                                       self.pred_index_all,
                                       self.val_loader.dataset.classes)
            print('Evaluation Time: {0} sec'.format(time.time() - start_time))
            # self.write_loss(phase=self.phase)
            return

        if cfg.MULTIPROCESSING_DISTRIBUTED:
            total_epoch = int(cfg.NITER_TOTAL / math.ceil(
                (self.train_image_num /
                 (cfg.BATCH_SIZE_TRAIN * len(cfg.GPU_IDS)))))
        else:
            total_epoch = int(cfg.NITER_TOTAL / math.ceil(
                (self.train_image_num / cfg.BATCH_SIZE_TRAIN)))

        print('total epoch:{0}, total iters:{1}'.format(
            total_epoch, cfg.NITER_TOTAL))

        for epoch in range(cfg.START_EPOCH, total_epoch + 1):

            if train_iters > cfg.NITER_TOTAL:
                break

            if cfg.MULTIPROCESSING_DISTRIBUTED:
                cfg.train_sampler.set_epoch(epoch)

            self.print_lr()

            # current_lr = util.poly_learning_rate(cfg.LR, train_iters, cfg.NITER_TOTAL, power=0.8)

            # if cfg.LR_POLICY != 'plateau':
            #     self.update_learning_rate(step=train_iters)
            # else:
            #     self.update_learning_rate(val=self.loss_meters['VAL_CLS_LOSS'].avg)

            self.fake_image_num = 0

            start_time = time.time()

            self.phase = 'train'
            self.net.train()

            # reset Averagemeters on each epoch
            for key in self.loss_meters:
                self.loss_meters[key].reset()

            iters = 0
            print('gpu_ids:', cfg.GPU_IDS)
            print('# Training images num = {0}'.format(self.train_image_num))
            # batch = tqdm(self.train_loader)
            # for data in batch:
            for data in self.train_loader:
                self.set_input(data)
                train_iters += 1
                iters += 1

                self._optimize(train_iters)
                self.update_learning_rate(step=train_iters)

                # self.val_iou = self.validate(train_iters)
                # self.write_loss(phase=self.phase, global_step=train_iters)

            print('log_path:', cfg.LOG_PATH)
            print('iters in one epoch:', iters)
            self.write_loss(phase=self.phase, global_step=train_iters)
            print('Epoch: {epoch}/{total}'.format(epoch=epoch,
                                                  total=total_epoch))
            util.print_current_errors(
                util.get_current_errors(self.loss_meters, current=False),
                epoch)
            print('Training Time: {0} sec'.format(time.time() - start_time))

            # if cfg.EVALUATE:
            if (epoch % self.cfg.EVALUATE_FREQ == 0 or epoch > total_epoch - 10
                    or epoch == total_epoch) and cfg.EVALUATE:
                print('# Cls val images num = {0}'.format(self.val_image_num))
                self.evaluate()
                self.print_evaluate_results()
                self.write_loss(phase=self.phase, global_step=train_iters)

                # save best model
                if cfg.SAVE_BEST and epoch > total_epoch - 10:
                    # save model
                    for key in self.loss_meters:
                        if 'MEAN' in key and self.loss_meters[key].val > 0:
                            if self.loss_meters[key].val > best_result:
                                best_result = self.loss_meters[key].val
                                model_filename = 'best_{0}.pth'.format(
                                    self.cfg.LOG_NAME)
                                print('best epoch / iters are {0}/{1}'.format(
                                    epoch, iters))
                                self.save_checkpoint(model_filename)
                            print('best {0} is {1}, epoch is {2}, iters {3}'.
                                  format(key, best_result, epoch, iters))

            print('End of iter {0} / {1} \t '
                  'Time Taken: {2} sec'.format(train_iters, cfg.NITER_TOTAL,
                                               time.time() - start_time))
            print('-' * 80)

    def evaluate(self):

        if not self.cfg.SLIDE_WINDOWS:
            self.validate()
        else:
            self.validate_slide_window()

    def save_best(self, best_result, epoch=None, iters=None):

        if self.cfg.TASK_TYPE == 'segmentation':
            result = self.loss_meters['VAL_CLS_MEAN_IOU'].val
        elif self.cfg.TASK_TYPE == 'recognition':
            result = self.loss_meters['VAL_CLS_MEAN_ACC'].val

        is_best = result > best_result
        best_result = max(result, best_result)
        if is_best:
            model_filename = 'best_{0}.pth'.format(self.cfg.LOG_NAME)
            print('best epoch / iters are {0}/{1}'.format(epoch, iters))
            self.save_checkpoint(model_filename)
            print('best miou is {0}, epoch is {1}, iters {2}'.format(
                best_result, epoch, iters))

    def print_evaluate_results(self):
        if self.cfg.TASK_TYPE == 'segmentation':
            print('MIOU: {miou}, mAcc: {macc}, acc: {acc}'.format(
                miou=self.loss_meters['VAL_CLS_MEAN_IOU'].val * 100,
                macc=self.loss_meters['VAL_CLS_MEAN_ACC'].val * 100,
                acc=self.loss_meters['VAL_CLS_ACC'].val * 100))

        elif self.cfg.TASK_TYPE == 'recognition':
            print('Mean Acc Top1 <{mean_acc:.3f}> '.format(
                mean_acc=self.loss_meters['VAL_CLS_MEAN_ACC'].val * 100))
        elif self.cfg.TASK_TYPE == 'infomax':
            print('Mean Acc Top1 MLP: <{mean_acc:.3f}> '.format(
                mean_acc=self.loss_meters['VAL_CLS_MEAN_ACC_MLP'].val * 100))

    def _forward(self, cal_loss=True):

        if self.cfg.USE_FAKE_DATA:
            with torch.no_grad():
                result_sample = self.sample_model(source=self.source_modal,
                                                  target=None,
                                                  label=None,
                                                  phase=self.phase,
                                                  cal_loss=False)

            fake_imgs = result_sample['gen_img']
            input_num = len(fake_imgs)
            indexes = [i for i in range(input_num)]
            random_index = random.sample(
                indexes, int(len(fake_imgs) * self.cfg.FAKE_DATA_RATE))

            for i in random_index:
                self.source_modal[i, :] = fake_imgs.data[i, :]

        self.result = self.net(source=self.source_modal,
                               target=self.target_modal,
                               label=self.label,
                               phase=self.phase,
                               cal_loss=cal_loss)

    def _construct_loss(self, iter):

        loss_total = torch.zeros(1).to(self.device)

        if 'CLS' in self.cfg.LOSS_TYPES:

            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                cls_loss = self.result['loss_cls'] * self.cfg.ALPHA_CLS
                loss_total += cls_loss

                dist.all_reduce(cls_loss)

                if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                    cls_loss_compl = self.result[
                        'loss_cls_compl'] * self.cfg.ALPHA_CLS
                    loss_total += cls_loss_compl

                    # cls_loss_fuse = self.result['loss_cls_fuse'] * self.cfg.ALPHA_CLS
                    # loss_total += cls_loss_fuse

                    dist.all_reduce(cls_loss_compl)

            else:
                cls_loss = self.result['loss_cls'].mean() * self.cfg.ALPHA_CLS
                loss_total += cls_loss

                if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                    cls_loss_compl = self.result['loss_cls_compl'].mean(
                    ) * self.cfg.ALPHA_CLS
                    loss_total += cls_loss_compl
                    # cls_loss_fuse = self.result['loss_cls_fuse'].mean() * self.cfg.ALPHA_CLS
                    # loss_total += cls_loss_fuse

            self.loss_meters['TRAIN_CLS_LOSS'].update(cls_loss.item(),
                                                      self.batch_size)
            if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                self.loss_meters['TRAIN_CLS_LOSS_COMPL'].update(
                    cls_loss_compl.item(), self.batch_size)
                # self.loss_meters['TRAIN_CLS_LOSS_FUSE'].update(cls_loss_fuse.item(), self.batch_size)

        # content-supervised (semantic) loss
        if 'SEMANTIC' in self.cfg.LOSS_TYPES:

            if self.cfg.MULTI_MODAL:
                self.gen = [self.result['gen_img_1'], self.result['gen_img_2']]
            else:
                self.gen = self.result['gen_img']

            decay_coef = 1
            # decay_coef = (iters / self.cfg.NITER_TOTAL)  # small to big
            # decay_coef = max(0, (self.cfg.NITER_TOTAL - iter) / self.cfg.NITER_TOTAL) # big to small
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                content_loss = self.result[
                    'loss_content'] * self.cfg.ALPHA_CONTENT * decay_coef
                loss_total += content_loss

                dist.all_reduce(content_loss)
                # content_loss = content_loss.detach() / self.batch_size

            else:
                content_loss = self.result['loss_content'].mean(
                ) * self.cfg.ALPHA_CONTENT * decay_coef
                loss_total += content_loss

            self.loss_meters['TRAIN_SEMANTIC_LOSS'].update(
                content_loss.item(), self.batch_size)

        if 'PIX2PIX' in self.cfg.LOSS_TYPES:

            if self.cfg.MULTI_MODAL:
                self.gen = [self.result['gen_img_1'], self.result['gen_img_2']]
            else:
                self.gen = self.result['gen_img']

            decay_coef = 1
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                pix2pix_loss = self.result[
                    'loss_pix2pix'] * self.cfg.ALPHA_PIX2PIX * decay_coef
                loss_total += pix2pix_loss
            else:
                pix2pix_loss = self.result['loss_pix2pix'].mean(
                ) * self.cfg.ALPHA_PIX2PIX * decay_coef
                loss_total += pix2pix_loss

            self.loss_meters['TRAIN_PIX2PIX_LOSS'].update(
                pix2pix_loss.item(), self.batch_size)

        if 'GAN' in self.cfg.LOSS_TYPES:

            real_g = self.result['gen_img']
            # real_g = torch.cat((self.result['gen_img'], self.source_modal), 1)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                loss_gan_g = self.discriminator(real_g,
                                                True) * self.cfg.ALPHA_GAN
            else:
                loss_gan_g = self.discriminator(
                    real_g, True).mean() * self.cfg.ALPHA_GAN
            self.loss_meters['TRAIN_GAN_G_LOSS'].update(
                loss_gan_g.item(), self.batch_size)

            loss_total += loss_gan_g

        return loss_total

    def set_log_data(self, cfg):

        self.loss_meters = defaultdict()
        self.log_keys = [
            'TRAIN_GAN_G_LOSS',
            'TRAIN_GAN_D_LOSS',
            'TRAIN_SEMANTIC_LOSS',  # semantic
            'TRAIN_PIX2PIX_LOSS',
            'TRAIN_CLS_ACC',
            'VAL_CLS_ACC',  # classification
            'TRAIN_CLS_LOSS',
            'TRAIN_CLS_MEAN_IOU',
            'VAL_CLS_LOSS',
            'VAL_CLS_MEAN_IOU',
            'VAL_CLS_MEAN_ACC',
            'INTERSECTION',
            'UNION',
            'LABEL',
            'TRAIN_CLS_LOSS_COMPL',
            'TRAIN_CLS_LOSS_FUSE'
        ]
        for item in self.log_keys:
            self.loss_meters[item] = AverageMeter()

    def set_optimizer(self, cfg):

        self.optimizers = []

        # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=cfg.LR, betas=(0.5, 0.999))

        if self.params_list:
            self.optimizer = torch.optim.Adam(self.params_list,
                                              lr=cfg.LR,
                                              betas=(0.5, 0.999))
        else:
            self.optimizer = torch.optim.Adam(self.net.parameters(),
                                              lr=cfg.LR,
                                              betas=(0.5, 0.999))
            # self.optimizer = torch.optim.SGD(self.net.parameters(), lr=cfg.LR, momentum=cfg.MOMENTUM, weight_decay=cfg.WEIGHT_DECAY)

        if cfg.MULTIPROCESSING_DISTRIBUTED:
            if cfg.USE_APEX:
                self.net, self.optimizer = apex.amp.initialize(
                    self.net.cuda(), self.optimizer, opt_level=cfg.opt_level)
                self.net = DDP(self.net)

            else:
                self.net = torch.nn.parallel.DistributedDataParallel(
                    self.net.cuda(), device_ids=[cfg.gpu])

        self.optimizers.append(self.optimizer)

        if 'GAN' in self.cfg.LOSS_TYPES:
            self.optimizer_d = torch.optim.SGD(self.discriminator.parameters(),
                                               lr=cfg.LR,
                                               momentum=0.9,
                                               weight_decay=0.0005)

            if cfg.MULTIPROCESSING_DISTRIBUTED:
                if cfg.USE_APEX:
                    self.discriminator, self.optimizer_d = apex.amp.initialize(
                        self.discriminator.cuda(),
                        self.optimizer_d,
                        opt_level=cfg.opt_level)
                    self.discriminator = DDP(self.discriminator)
                else:
                    self.discriminator = torch.nn.parallel.DistributedDataParallel(
                        self.discriminator.cuda(), device_ids=[cfg.gpu])

            self.optimizers.append(self.optimizer_d)

    def validate_slide_window(self):
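        # Sliding-window evaluation: util.slide_cal is expected to run the
        # model on FINE_SIZE crops of the full image, accumulate the logits in
        # prediction_matrix and normalize overlapping regions with
        # count_crop_matrix.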

        self.net.eval()
        self.phase = 'test'

        intersection_meter = self.loss_meters['INTERSECTION']
        union_meter = self.loss_meters['UNION']
        target_meter = self.loss_meters['LABEL']

        print('testing with sliding windows...')
        num_images = 0

        # batch = tqdm(self.val_loader)
        # for data in batch:
        for data in self.val_loader:
            self.set_input(data)
            num_images += self.batch_size
            pred = util.slide_cal(model=self.net,
                                  image=self.source_modal,
                                  crop_size=self.cfg.FINE_SIZE,
                                  prediction_matrix=self.prediction_matrix[
                                      0:self.batch_size, :, :, :],
                                  count_crop_matrix=self.count_crop_matrix[
                                      0:self.batch_size, :, :, :])

            self.pred = pred.data.max(1)[1]
            intersection, union, label = util.intersectionAndUnionGPU(
                self.pred, self.label, self.cfg.NUM_CLASSES)
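            # In distributed training, sum the per-process counts so that
            # every rank reports the global metrics.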
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(label)
            intersection, union, label = intersection.cpu().numpy(), union.cpu(
            ).numpy(), label.cpu().numpy()

            intersection_meter.update(intersection, self.batch_size)
            union_meter.update(union, self.batch_size)
            target_meter.update(label, self.batch_size)

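        # Per-class IoU = TP / (TP + FP + FN) and per-class accuracy
        # = TP / (TP + FN); allAcc is the overall accuracy over all labeled
        # pixels.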
        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        mIoU = np.mean(iou_class)
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

        self.loss_meters['VAL_CLS_ACC'].update(allAcc)
        self.loss_meters['VAL_CLS_MEAN_ACC'].update(mAcc)
        self.loss_meters['VAL_CLS_MEAN_IOU'].update(mIoU)

    def validate(self):
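        # Full-image validation: run a forward pass on each batch, take the
        # argmax prediction and accumulate intersection/union/label counts for
        # accuracy and (for segmentation) mean IoU.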

        self.phase = 'test'

        # switch to evaluate mode
        self.net.eval()

        intersection_meter = self.loss_meters['INTERSECTION']
        union_meter = self.loss_meters['UNION']
        target_meter = self.loss_meters['LABEL']

        if self.cfg.USE_FAKE_DATA or self.cfg.INFERENCE:
            self.pred_index_all = []
            self.target_index_all = []

        with torch.no_grad():

            # batch_index = int(self.val_image_num / cfg.BATCH_SIZE)
            # random_id = random.randint(0, batch_index)

            # batch = tqdm(self.val_loader)
            # for data in batch:
            for i, data in enumerate(self.val_loader):
                self.set_input(data)

                self._forward(cal_loss=False)
                if self.cfg.INFERENCE:
                    self._process_fc()
                self.pred = self.result['cls'].data.max(1)[1]
                intersection, union, label = util.intersectionAndUnionGPU(
                    self.pred, self.label, self.cfg.NUM_CLASSES)
                if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                    dist.all_reduce(intersection)
                    dist.all_reduce(union)
                    dist.all_reduce(label)

                intersection, union, label = intersection.cpu().numpy(
                ), union.cpu().numpy(), label.cpu().numpy()

                intersection_meter.update(intersection, self.batch_size)
                union_meter.update(union, self.batch_size)
                target_meter.update(label, self.batch_size)

        # Mean ACC
        # self._cal_mean_acc(self.cfg,self.val_loader)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        self.accuracy_class = accuracy_class
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

        self.loss_meters['VAL_CLS_ACC'].update(allAcc)
        self.loss_meters['VAL_CLS_MEAN_ACC'].update(mAcc)

        if self.cfg.TASK_TYPE == 'segmentation':
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            mIoU = np.mean(iou_class)
            self.loss_meters['VAL_CLS_MEAN_IOU'].update(mIoU)

    def _process_fc(self):
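        # Collect top-1 predictions and ground-truth indices so that a
        # per-class mean accuracy can be computed after inference.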

        # dist.all_reduce(self.result['cls'])
        _, index = self.result['cls'].data.topk(1, 1, largest=True)

        self.pred_index_all.extend(list(index.cpu().numpy()))
        self.target_index_all.extend(list(self._label.numpy()))

    def _cal_mean_acc(self, cfg, data_loader):
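        # Per-class mean accuracy over the indices accumulated by _process_fc.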

        mean_acc = util.mean_acc(np.array(self.target_index_all),
                                 np.array(self.pred_index_all),
                                 cfg.NUM_CLASSES, data_loader.dataset.classes)
        return mean_acc

    def write_loss(self, phase, global_step=1):
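        # Log images and scalar losses/metrics to TensorBoard; the 'train'
        # branch also visualizes generated/target modalities and colorized
        # segmentation maps, while the 'test' branch logs validation
        # accuracies.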

        loss_types = self.cfg.LOSS_TYPES
        task = self.cfg.TASK_TYPE

        if self.phase == 'train':
            label_show = self.label.data.cpu().numpy()
        else:
            label_show = np.uint8(self.label.data.cpu())

        source_modal_show = self.source_modal
        target_modal_show = self.target_modal

        if phase == 'train':

            self.writer.add_image(task + '/Train_image',
                                  torchvision.utils.make_grid(
                                      source_modal_show[:6].clone().cpu().data,
                                      3,
                                      normalize=True),
                                  global_step=global_step)
            self.writer.add_scalar(task + '/LR',
                                   self.optimizer.param_groups[0]['lr'],
                                   global_step=global_step)

            if 'CLS' in loss_types:
                self.writer.add_scalar(task + '/TRAIN_CLS_LOSS',
                                       self.loss_meters['TRAIN_CLS_LOSS'].avg,
                                       global_step=global_step)
                if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                    self.writer.add_scalar(
                        task + '/TRAIN_CLS_LOSS_COMPL',
                        self.loss_meters['TRAIN_CLS_LOSS_COMPL'].avg,
                        global_step=global_step)
                    self.writer.add_image(
                        task + '/Compl_image',
                        torchvision.utils.make_grid(
                            self.result['compl_source'][:6].clone().cpu().data,
                            3,
                            normalize=True),
                        global_step=global_step)
                # self.writer.add_scalar('TRAIN_CLS_ACC', self.loss_meters['TRAIN_CLS_ACC'].avg*100.0,
                #                        global_step=global_step)
                # self.writer.add_scalar('TRAIN_CLS_MEAN_IOU', float(self.train_iou.mean())*100.0,
                #                        global_step=global_step)

            if self.trans and not self.cfg.MULTI_MODAL:

                if 'SEMANTIC' in self.cfg.LOSS_TYPES:
                    self.writer.add_scalar(
                        task + '/TRAIN_SEMANTIC_LOSS',
                        self.loss_meters['TRAIN_SEMANTIC_LOSS'].avg,
                        global_step=global_step)
                if 'PIX2PIX' in self.cfg.LOSS_TYPES:
                    self.writer.add_scalar(
                        task + '/TRAIN_PIX2PIX_LOSS',
                        self.loss_meters['TRAIN_PIX2PIX_LOSS'].avg,
                        global_step=global_step)

                self.writer.add_image(task + '/Train_gen',
                                      torchvision.utils.make_grid(
                                          self.gen.data[:6].clone().cpu().data,
                                          3,
                                          normalize=True),
                                      global_step=global_step)
                self.writer.add_image(
                    task + '/Train_image',
                    torchvision.utils.make_grid(
                        source_modal_show[:6].clone().cpu().data,
                        3,
                        normalize=True),
                    global_step=global_step)
                # if isinstance(self.target_modal, list):
                #     for i, (gen, target) in enumerate(zip(self.gen, self.target_modal)):
                #         self.writer.add_image('Seg/2_Train_Gen_' + str(self.cfg.FINE_SIZE / pow(2, i)),
                #                               torchvision.utils.make_grid(gen[:6].clone().cpu().data, 3,
                #                                                           normalize=True),
                #                               global_step=global_step)
                #         self.writer.add_image('Seg/3_Train_Target_' + str(self.cfg.FINE_SIZE / pow(2, i)),
                #                               torchvision.utils.make_grid(target[:6].clone().cpu().data, 3,
                #                                                           normalize=True),
                #                               global_step=global_step)
                # else:
                self.writer.add_image(
                    task + '/Train_target',
                    torchvision.utils.make_grid(
                        target_modal_show[:6].clone().cpu().data,
                        3,
                        normalize=True),
                    global_step=global_step)

            if 'CLS' in loss_types and self.cfg.TASK_TYPE == 'segmentation':
                train_pred = self.result['cls'].data.max(1)[1].cpu().numpy()
                self.writer.add_image(
                    task + '/Train_predicted',
                    torchvision.utils.make_grid(torch.from_numpy(
                        util.color_label(train_pred[:6],
                                         ignore=self.cfg.IGNORE_LABEL,
                                         dataset=self.cfg.DATASET)),
                                                3,
                                                normalize=True,
                                                range=(0, 255)),
                    global_step=global_step)
                self.writer.add_image(
                    task + '/Train_label',
                    torchvision.utils.make_grid(torch.from_numpy(
                        util.color_label(label_show[:6],
                                         ignore=self.cfg.IGNORE_LABEL,
                                         dataset=self.cfg.DATASET)),
                                                3,
                                                normalize=True,
                                                range=(0, 255)),
                    global_step=global_step)

        elif phase == 'test':

            self.writer.add_image(task + '/Val_image',
                                  torchvision.utils.make_grid(
                                      source_modal_show[:6].clone().cpu().data,
                                      3,
                                      normalize=True),
                                  global_step=global_step)
            # self.writer.add_image('Seg/Val_image',
            #                       torchvision.utils.make_grid(source_modal_show[:6].clone().cpu().data, 3,
            #                                                   normalize=True), global_step=global_step)
            #
            # self.writer.add_image('Seg/Val_predicted',
            #                       torchvision.utils.make_grid(
            #                           torch.from_numpy(util.color_label(self.pred[:6], ignore=self.cfg.IGNORE_LABEL,
            #                                                             dataset=self.cfg.DATASET)), 3,
            #                           normalize=True, range=(0, 255)), global_step=global_step)
            # self.writer.add_image('Seg/Val_label',
            #                       torchvision.utils.make_grid(torch.from_numpy(
            #                           util.color_label(label_show[:6], ignore=self.cfg.IGNORE_LABEL,
            #                                            dataset=self.cfg.DATASET)),
            #                           3, normalize=True, range=(0, 255)),
            #                       global_step=global_step)

            if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                self.writer.add_image(
                    task + '/Compl_image',
                    torchvision.utils.make_grid(
                        self.result['compl_source'][:6].clone().cpu().data,
                        3,
                        normalize=True),
                    global_step=global_step)

            self.writer.add_scalar(task + '/VAL_CLS_ACC',
                                   self.loss_meters['VAL_CLS_ACC'].val * 100.0,
                                   global_step=global_step)
            self.writer.add_scalar(task + '/VAL_CLS_MEAN_ACC',
                                   self.loss_meters['VAL_CLS_MEAN_ACC'].val *
                                   100.0,
                                   global_step=global_step)
            if task == 'segmentation':
                self.writer.add_scalar(
                    task + '/VAL_CLS_MEAN_IOU',
                    self.loss_meters['VAL_CLS_MEAN_IOU'].val * 100.0,
                    global_step=global_step)