def _make_parallel(runner, net):
    if runner.configer.get('network.distributed', default=False):
        from apex.parallel import DistributedDataParallel
        if runner.configer.get('network.syncbn', default=False):
            Log.info('Converting syncbn model...')
            from apex.parallel import convert_syncbn_model
            net = convert_syncbn_model(net)
        torch.cuda.set_device(runner.configer.get('local_rank'))
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
        return net

    net = net.to(torch.device('cpu' if runner.configer.get('gpu') is None else 'cuda'))
    if len(runner.configer.get('gpu')) > 1:
        from exts.tools.parallel.data_parallel import ParallelModel
        return ParallelModel(net, gather_=runner.configer.get('network', 'gather'))
    return net
def _make_parallel(runner, net):
    if runner.configer.get('network.distributed', default=False):
        from apex.parallel import DistributedDataParallel
        torch.cuda.set_device(runner.configer.get('local_rank'))
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
        return net
    else:
        net = net.to(torch.device('cpu' if runner.configer.get('gpu') is None else 'cuda'))
        from exts.tools.parallel.data_parallel import ParallelModel
        return ParallelModel(net, gather_=runner.configer.get('network', 'gather'))
def make_parallel(runner, net, optimizer):
    if runner.configer.get('distributed', default=False):
        from apex.parallel import DistributedDataParallel
        if runner.configer.get('network.syncbn', default=False):
            Log.info('Converting syncbn model...')
            from apex.parallel import convert_syncbn_model
            net = convert_syncbn_model(net)
        torch.cuda.set_device(runner.configer.get('local_rank'))
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        if runner.configer.get('dtype') == 'fp16':
            from apex import amp
            net, optimizer = amp.initialize(net.cuda(), optimizer, opt_level="O1")
            net = DistributedDataParallel(net, delay_allreduce=True)
        else:
            assert runner.configer.get('dtype') == 'none'
            net = DistributedDataParallel(net.cuda(), delay_allreduce=True)
        return net, optimizer

    net = net.to(torch.device('cpu' if runner.configer.get('gpu') is None else 'cuda'))
    if len(runner.configer.get('gpu')) > 1:
        from lib.utils.parallel.data_parallel import DataParallelModel
        return DataParallelModel(net, gather_=runner.configer.get('network', 'gather')), optimizer
    return net, optimizer
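# Hedged usage sketch for the helpers above (the helper name below is ours, not
# part of the original code): init_method='env://' assumes one process per GPU
# with MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE already set in the
# environment, e.g. by
#
#   python -m torch.distributed.launch --nproc_per_node=4 train.py
#
import os
import torch
import torch.distributed as dist


def init_distributed(local_rank: int) -> bool:
    """Return True when launched with more than one process; mirrors the setup above."""
    if int(os.environ.get('WORLD_SIZE', '1')) <= 1:
        return False
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    return True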
class nnUNetTrainerV2_DDP(nnUNetTrainerV2):
    def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None,
                 batch_dice=True, stage=None, unpack_data=True, deterministic=True,
                 distribute_batch_size=False, fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
                         unpack_data, deterministic, fp16)
        self.init_args = (plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice,
                          stage, unpack_data, deterministic, distribute_batch_size, fp16)
        self.distribute_batch_size = distribute_batch_size
        np.random.seed(local_rank)
        torch.manual_seed(local_rank)
        torch.cuda.manual_seed_all(local_rank)
        self.local_rank = local_rank
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        self.val_loss_ma_alpha = 0.95
        self.val_loss_MA = None
        self.loss = None
        self.ce_loss = CrossentropyND()
        self.global_batch_size = None  # we need to know this to properly steer oversample

    def set_batch_size_and_oversample(self):
        batch_sizes = []
        oversample_percents = []
        world_size = dist.get_world_size()
        my_rank = dist.get_rank()
        if self.distribute_batch_size:
            self.global_batch_size = self.batch_size
        else:
            self.global_batch_size = self.batch_size * world_size
        batch_size_per_GPU = np.ceil(self.batch_size / world_size).astype(int)
        for rank in range(world_size):
            if self.distribute_batch_size:
                if (rank + 1) * batch_size_per_GPU > self.batch_size:
                    batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - self.batch_size)
                else:
                    batch_size = batch_size_per_GPU
            else:
                batch_size = self.batch_size
            batch_sizes.append(batch_size)

            sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1])
            sample_id_high = np.sum(batch_sizes)
            if sample_id_high / self.global_batch_size < (1 - self.oversample_foreground_percent):
                oversample_percents.append(0.0)
            elif sample_id_low / self.global_batch_size > (1 - self.oversample_foreground_percent):
                oversample_percents.append(1.0)
            else:
                percent_covered_by_this_rank = sample_id_high / self.global_batch_size - \
                                               sample_id_low / self.global_batch_size
                oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) -
                                                sample_id_low / self.global_batch_size) /
                                               percent_covered_by_this_rank)
                oversample_percents.append(oversample_percent_here)

        print("worker", my_rank, "oversample", oversample_percents[my_rank])
        print("worker", my_rank, "batch_size", batch_sizes[my_rank])
        self.batch_size = batch_sizes[my_rank]
        self.oversample_foreground_percent = oversample_percents[my_rank]

    def save_checkpoint(self, fname, save_optimizer=True):
        if self.local_rank == 0:
            super().save_checkpoint(fname, save_optimizer)

    def plot_progress(self):
        if self.local_rank == 0:
            super().plot_progress()

    def print_to_log_file(self, *args, also_print_to_console=True):
        if self.local_rank == 0:
            super().print_to_log_file(*args, also_print_to_console=also_print_to_console)

    def initialize_network(self):
        """
        This is specific to the U-Net and must be adapted for other network architectures
        :return:
        """
        self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
        self.print_to_log_file(self.net_conv_kernel_sizes)
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d
        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
                                    len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
                                    conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
                                    net_nonlin, net_nonlin_kwargs, True, False, lambda x: x,
                                    InitWeights_He(1e-2), self.net_num_pool_op_kernel_sizes,
                                    self.net_conv_kernel_sizes, False, True, True)
        self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def process_plans(self, plans):
        super().process_plans(plans)
        self.set_batch_size_and_oversample()

    def initialize(self, training=True, force_load_plans=False):
        """
        For prediction of test cases just set training=False, this will prevent loading of training
        data and training batchgenerator initialization
        :param training:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)
            if force_load_plans or (self.plans is None):
                self.load_plans_file()
            self.process_plans(self.plans)
            self.setup_DA_params()
            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        # we need to wait until worker 0 has finished unpacking
                        npz_files = subfiles(self.folder_with_preprocessed_data, suffix=".npz", join=False)
                        case_ids = [i[:-4] for i in npz_files]
                        all_present = all(
                            [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        while not all_present:
                            print("worker", self.local_rank, "is waiting for unpacking")
                            sleep(3)
                            all_present = all(
                                [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        # there is some slight chance that an error may arise because a dataloader is loading a
                        # file that is still being written by worker 0. We ignore this for now and address it
                        # only if it becomes relevant
                        # (this can occur because while worker 0 is writing, the file is technically present, so
                        # the other workers will proceed and eventually try to read it)
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. Pray you are not using "
                          "2d or you will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)
                # we give each output a weight which decreases exponentially (division by 2) as the resolution
                # decreases. This gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
                # mask out the lowest resolution output, then normalize the weights so that they sum to 1
                mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights

                # note: np.random.random_integers is deprecated in newer numpy; np.random.randint is the
                # modern equivalent
                seeds_train = np.random.random_integers(0, 99999, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.random_integers(0, 99999,
                                                      max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val, self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params, deep_supervision_scales=self.deep_supervision_scales,
                    seeds_train=seeds_train, seeds_val=seeds_val)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            self._maybe_init_amp()
            self.network = DDP(self.network)
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()
        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))
            # the network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])
            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of
            # 3 variables. do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]
            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUs to do proper
                # batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)

            ce_loss = self.ce_loss(output[i], target[i])
            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                num_classes = output[0].shape[1]
                output_seg = output[0].argmax(1)
                target = target[0][:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
                    fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
                    fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]
                tp_hard = awesome_allgather_function.apply(tp_hard)
                fp_hard = awesome_allgather_function.apply(fp_hard)
                fn_hard = awesome_allgather_function.apply(fn_hard)

                self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                           fp_hard.detach().cpu().numpy().sum(0),
                                           fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None:
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()

    def run_online_evaluation(self, tp, fp, fn):
        self.online_eval_foreground_dc.append(list((2 * tp) / (2 * tp + fp + fn + 1e-8)))
        self.online_eval_tp.append(list(tp))
        self.online_eval_fp.append(list(fp))
        self.online_eval_fn.append(list(fn))

    def run_training(self):
        """
        if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run
        the first continued epoch with self.initial_lr

        we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
        :return:
        """
        self.maybe_update_lr(self.epoch)  # if we don't overwrite epoch then self.epoch+1 is used, which is
        # not what we want at the start of the training
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = True
        ret = nnUNetTrainer.run_training(self)
        net.do_ds = ds
        return ret

    def validate(self, do_mirroring: bool = True, use_train_mode: bool = False, tiled: bool = True,
                 step: int = 2, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
                 validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
                 force_separate_z: bool = None, interpolation_order: int = 3, interpolation_order_z=0):
        if self.local_rank == 0:
            if isinstance(self.network, DDP):
                net = self.network.module
            else:
                net = self.network
            ds = net.do_ds
            net.do_ds = False
            ret = nnUNetTrainer.validate(self, do_mirroring, use_train_mode, tiled, step, save_softmax,
                                         use_gaussian, overwrite, validation_folder_name, debug, all_in_gpu,
                                         force_separate_z=force_separate_z,
                                         interpolation_order=interpolation_order,
                                         interpolation_order_z=interpolation_order_z)
            net.do_ds = ds
            return ret

    def predict_preprocessed_data_return_softmax(self, data, do_mirroring, num_repeats, use_train_mode,
                                                 batch_size, mirror_axes, tiled, tile_in_z, step, min_size,
                                                 use_gaussian, all_in_gpu=False):
        """
        Don't use this. If you need softmax output, use preprocess_predict_nifti and set softmax_output_file.
        :param data:
        :param do_mirroring:
        :param num_repeats:
        :param use_train_mode:
        :param batch_size:
        :param mirror_axes:
        :param tiled:
        :param tile_in_z:
        :param step:
        :param min_size:
        :param use_gaussian:
        :return:
        """
        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel, DDP))
        if isinstance(self.network, DDP):
            net = self.network.module
        else:
            net = self.network
        ds = net.do_ds
        net.do_ds = False
        ret = net.predict_3D(data, do_mirroring, num_repeats, use_train_mode, batch_size, mirror_axes,
                             tiled, tile_in_z, step, min_size, use_gaussian=use_gaussian,
                             pad_border_mode=self.inference_pad_border_mode,
                             pad_kwargs=self.inference_pad_kwargs, all_in_gpu=all_in_gpu)[2]
        net.do_ds = ds
        return ret

    def load_checkpoint_ram(self, saved_model, train=True):
        """
        used for if the checkpoint is already in ram
        :param saved_model:
        :param train:
        :return:
        """
        if not self.was_initialized:
            self.initialize(train)

        new_state_dict = OrderedDict()
        curr_state_dict_keys = list(self.network.state_dict().keys())
        # if the state dict comes from an nn.DataParallel model but we use a non-parallel model here, the
        # state dict keys do not match. Use a heuristic to make them match
        for k, value in saved_model['state_dict'].items():
            key = k
            if key not in curr_state_dict_keys:
                key = key[7:]  # strip the 'module.' prefix
            new_state_dict[key] = value

        # if we are fp16, then we need to reinitialize the network and the optimizer. Otherwise amp will
        # throw an error
        if self.fp16:
            self.network, self.optimizer, self.lr_scheduler = None, None, None
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            # we need to reinitialize DDP here
            self.network = DDP(self.network)

        self.network.load_state_dict(new_state_dict)
        self.epoch = saved_model['epoch']
        if train:
            optimizer_state_dict = saved_model['optimizer_state_dict']
            if optimizer_state_dict is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)
            if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and \
                    saved_model['lr_scheduler_state_dict'] is not None:
                self.lr_scheduler.load_state_dict(saved_model['lr_scheduler_state_dict'])
            if issubclass(self.lr_scheduler.__class__, _LRScheduler):
                self.lr_scheduler.step(self.epoch)

        self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = \
            saved_model['plot_stuff']

        # after the training is done, the epoch is incremented one more time in my old code. This results in
        # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
        # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct
        # that here
        if self.epoch != len(self.all_tr_losses):
            self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). "
                                   "This is due to an old bug and should only appear when you are loading "
                                   "old models. New models should have this fixed! self.epoch is now set to "
                                   "len(self.all_tr_losses)")
            self.epoch = len(self.all_tr_losses)
            self.all_tr_losses = self.all_tr_losses[:self.epoch]
            self.all_val_losses = self.all_val_losses[:self.epoch]
            self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
            self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]

        self.amp_initialized = False
        self._maybe_init_amp()
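# Hedged sketch of an autograd-aware all_gather in the spirit of the
# awesome_allgather_function used above (our reconstruction, not the original
# implementation): the forward pass gathers tensors from all ranks so batch
# dice can be computed globally, and the backward pass returns this rank's
# slice of the incoming gradient. Assumes equal first-dimension sizes per rank.
import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        # concatenate along the batch dimension, matching .sum(0) usage above
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        rank = dist.get_rank()
        chunk = grad_output.shape[0] // dist.get_world_size()
        # each rank keeps only the gradient of its own contribution
        return grad_output[rank * chunk:(rank + 1) * chunk]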
def main(model_name, mode, root, val_split, ckpt, batch_per_gpu):
    num_gpus = MPI.COMM_WORLD.Get_size()
    distributed = num_gpus > 1
    local_rank = MPI.COMM_WORLD.Get_rank() % torch.cuda.device_count()
    if distributed:
        torch.cuda.set_device(local_rank)
        host = os.environ["MASTER_ADDR"] if "MASTER_ADDR" in os.environ else "127.0.0.1"
        torch.distributed.init_process_group(backend="nccl",
                                             init_method='tcp://{}:12345'.format(host),
                                             rank=MPI.COMM_WORLD.Get_rank(),
                                             world_size=MPI.COMM_WORLD.Get_size())
        synchronize()

    val_dataloader = make_dataloader(root, val_split, mode, model_name,
                                     seq_len=16, overlap=8, phase='val', max_iters=None,
                                     batch_per_gpu=batch_per_gpu, num_workers=16, shuffle=False,
                                     distributed=distributed, with_normal=False)

    num_classes = val_dataloader.dataset.num_classes
    if model_name == 'i3d':
        in_channels = 2 if mode == 'flow' else 3
        model = InceptionI3d(num_classes, in_channels=in_channels, dropout_keep_prob=0.5)
        model.replace_logits(num_classes)
    elif model_name == 'r3d_18':
        model = r3d_18(pretrained=False, num_classes=num_classes)
    elif model_name == 'mc3_18':
        model = mc3_18(pretrained=False, num_classes=num_classes)
    elif model_name == 'r2plus1d_18':
        model = r2plus1d_18(pretrained=False, num_classes=num_classes)
    elif model_name == 'c3d':
        model = C3D(pretrained=False, num_classes=num_classes)
    else:
        raise NameError('unknown model name: {}'.format(model_name))

    device = torch.device('cuda')
    model.to(device)
    if distributed:
        # convert BN to SyncBN before wrapping with DDP
        model = apex.parallel.convert_syncbn_model(model)
        model = DDP(model.cuda(), delay_allreduce=True)
class Solver(object):

    def __init__(self):
        self.version = __version__
        self.distributed = False
        self.world_size = 1
        self.local_rank = 0
        self.epoch = 0
        self.iteration = 0
        self.config = None
        self.model, self.optimizer, self.lr_policy = None, None, None
        self.step_decay = 1

        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = self.world_size > 1 or torch.cuda.device_count() > 1
        if self.distributed:
            dist.init_process_group(backend="nccl", init_method='env://')
            self.local_rank = dist.get_rank()
            torch.cuda.set_device(self.local_rank)
            logging.info('[distributed mode] world size: {}, local rank: {}.'.format(
                self.world_size, self.local_rank))
        else:
            logging.info('[Single GPU mode]')

    def build_environ(self):
        if self.config['environ']['deterministic']:
            cudnn.benchmark = False
            cudnn.deterministic = True
            torch.set_printoptions(precision=10)
        else:
            cudnn.benchmark = True
        if self.config['apex']:
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

        # set random seed
        torch.manual_seed(self.config['environ']['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.config['environ']['seed'])
        np.random.seed(self.config['environ']['seed'])
        random.seed(self.config['environ']['seed'])

    def init_from_scratch(self, config):
        t_start = time.time()
        self.config = config
        self.build_environ()

        # model and optimizer
        self.model = _get_model(self.config)
        model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = _get_optimizer(config['solver']['optimizer'], model_params=model_params)
        self.lr_policy = _get_lr_policy(config['solver']['lr_policy'], optimizer=self.optimizer)
        self.step_decay = config['solver']['step_decay']

        if config['model'].get('pretrained_model') is not None:
            logging.info('loading pretrained model from {}.'.format(config['model']['pretrained_model']))
            load_model(self.model, config['model']['pretrained_model'], distributed=False)
        self.model.cuda(self.local_rank)
        if self.distributed:
            self.model = convert_syncbn_model(self.model)
        if self.config['apex']['amp_used']:
            # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info("Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}.".format(
                self.config['apex']['opt_level'], self.config['apex']['keep_batchnorm_fp32'],
                self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        t_end = time.time()
        logging.info("Init trainer from scratch, Time usage: IO: {}".format(t_end - t_start))

    def init_from_checkpoint(self, continue_state_object):
        t_start = time.time()
        self.config = continue_state_object['config']
        self.build_environ()

        self.model = _get_model(self.config)
        model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = _get_optimizer(self.config['solver']['optimizer'], model_params=model_params)
        self.lr_policy = _get_lr_policy(self.config['solver']['lr_policy'], optimizer=self.optimizer)

        load_model(self.model, continue_state_object['model'], distributed=False)
        self.model.cuda(self.local_rank)
        if self.distributed:
            self.model = convert_syncbn_model(self.model)
        if self.config['apex']['amp_used']:
            # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info("Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}.".format(
                self.config['apex']['opt_level'], self.config['apex']['keep_batchnorm_fp32'],
                self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
            amp.load_state_dict(continue_state_object['amp'])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        self.optimizer.load_state_dict(continue_state_object['optimizer'])
        self.lr_policy.load_state_dict(continue_state_object['lr_policy'])
        self.step_decay = self.config['solver']['step_decay']
        self.epoch = continue_state_object['epoch']
        self.iteration = continue_state_object["iteration"]
        del continue_state_object

        t_end = time.time()
        logging.info("Init trainer from checkpoint, Time usage: IO: {}".format(t_end - t_start))

    def step(self, **kwargs):
        self.iteration += 1
        loss = self.model(**kwargs)
        loss /= self.step_decay

        # backward
        if self.distributed and self.config['apex']['amp_used']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # accumulate gradients for step_decay iterations before stepping
        if self.iteration % self.step_decay == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        if self.distributed:
            reduced_loss = reduce_tensor(loss.data, self.world_size)
        else:
            reduced_loss = loss.data
        return reduced_loss

    def step_no_grad(self, **kwargs):
        with torch.no_grad():
            out = self.model(**kwargs)
        return out

    def before_epoch(self, epoch):
        self.iteration = 0
        self.epoch = epoch
        self.model.train()
        self.synchronize()
        torch.cuda.empty_cache()
        self.lr_policy.step(epoch)

    def after_epoch(self, epoch):
        self.model.eval()
        self.synchronize()
        torch.cuda.empty_cache()

    def synchronize(self):
        synchronize()

    def save_checkpoint(self, path):
        if self.local_rank == 0:
            t_start = time.time()
            state_dict = {}
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            # strip the 'module.' prefix added by DistributedDataParallel
            for k, v in self.model.state_dict().items():
                key = k
                if k.split('.')[0] == 'module':
                    key = k[7:]
                new_state_dict[key] = v
            if self.config['apex']['amp_used']:
                state_dict['amp'] = amp.state_dict()
            state_dict['config'] = self.config
            state_dict['model'] = new_state_dict
            state_dict['optimizer'] = self.optimizer.state_dict()
            state_dict['lr_policy'] = self.lr_policy.state_dict()
            state_dict['epoch'] = self.epoch
            state_dict['iteration'] = self.iteration

            t_iobegin = time.time()
            torch.save(state_dict, path)
            del state_dict
            del new_state_dict
            t_end = time.time()
            logging.info("Save checkpoint to file {}, "
                         "Time usage:\n\tprepare snapshot: {}, IO: {}".format(
                             path, t_iobegin - t_start, t_end - t_iobegin))

    def save_images(self, filenames, image):
        raise NotImplementedError

    def copy_config(self, snapshot_dir, config_file):
        ensure_dir(snapshot_dir)
        assert osp.exists(config_file), "config file does not exist."
        new_file_name = osp.join(snapshot_dir, 'config.json')
        shutil.copy(config_file, new_file_name)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        torch.cuda.empty_cache()
        if type is not None:
            logging.warning("An exception occurred during Engine initialization, "
                            "giving up pspnet_ade process")
            return False
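# Hedged sketch of the synchronize() helper the Solver calls (assumed, not
# shown in this file): a barrier so all ranks line up between epochs. This is
# the pattern commonly used in distributed training utilities.
import torch.distributed as dist


def synchronize():
    """No-op unless torch.distributed is initialized with more than one process."""
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()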
class RunManager:

    def __init__(self, path, net, run_config: RunConfig, out_log=True):
        self.path = path
        self.net = net
        self.run_config = run_config
        self.out_log = out_log
        self._logs_path, self._save_path = None, None
        self.best_acc = 0
        self.start_epoch = 0

        gpu = self.run_config.local_rank
        torch.cuda.set_device(gpu)

        # initialize model (default)
        self.net.init_model(run_config.model_init, run_config.init_div_groups)

        # net info
        self.net = self.net.cuda()
        if run_config.local_rank == 0:
            self.print_net_info()
        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        # scale the learning rate linearly with the global batch size
        self.run_config.init_lr = self.run_config.init_lr * float(
            self.run_config.train_batch_size * self.run_config.world_size) / 256.
        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(self.net.weight_parameters())
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True

    """ save path and log path """

    @property
    def save_path(self):
        if self._save_path is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self._save_path = save_path
        return self._save_path

    @property
    def logs_path(self):
        if self._logs_path is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self._logs_path = logs_path
        return self._logs_path

    """ net info """

    def reset_model(self, model, model_origin=None):
        self.net = model
        self.net.init_model(self.run_config.model_init, self.run_config.init_div_groups)
        if model_origin is not None:
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' start pruning ' + '-' * 30)
            get_unpruned_weights(self.net, model_origin)
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' end pruning ' + '-' * 30)

        # net info
        self.net = self.net.cuda()
        if self.run_config.local_rank == 0:
            self.print_net_info()
        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(self.net.weight_parameters())
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True

    # noinspection PyUnresolvedReferences
    def net_flops(self):
        data_shape = [1] + list(self.run_config.data_provider.data_shape)
        net = self.net
        input_var = torch.zeros(data_shape).cuda()
        with torch.no_grad():
            flops = profile_macs(net, input_var)
        return flops

    def print_net_info(self):
        # parameters
        total_params = count_parameters(self.net)
        if self.out_log:
            print('Total training params: %.2fM' % (total_params / 1e6))
        net_info = {'param': '%.2fM' % (total_params / 1e6)}

        # flops
        flops = self.net_flops()
        if self.out_log:
            print('Total FLOPs: %.1fM' % (flops / 1e6))
        net_info['flops'] = '%.1fM' % (flops / 1e6)

        # config
        if self.out_log:
            print('Net config: ' + str(self.net.config))
        net_info['config'] = str(self.net.config)

        with open('%s/net_info.txt' % self.logs_path, 'w') as fout:
            fout.write(json.dumps(net_info, indent=4) + '\n')

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.net.module.state_dict()}
        if model_name is None:
            model_name = 'checkpoint.pth.tar'
        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)
        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            if self.out_log:
                print("=> loading checkpoint '{}'".format(model_fname))

            if torch.cuda.is_available():
                checkpoint = torch.load(model_fname)
            else:
                checkpoint = torch.load(model_fname, map_location='cpu')

            self.net.module.load_state_dict(checkpoint['state_dict'])
            # set new manual seed
            new_manual_seed = int(time.time())
            torch.manual_seed(new_manual_seed)
            torch.cuda.manual_seed_all(new_manual_seed)
            np.random.seed(new_manual_seed)

            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'best_acc' in checkpoint:
                self.best_acc = checkpoint['best_acc']
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            if self.out_log:
                print("=> loaded checkpoint '{}'".format(model_fname))
        except Exception:
            if self.out_log:
                print('failed to load checkpoint from %s' % self.save_path)

    def save_config(self, print_info=True):
        """ dump run_config and net_config to the model_folder """
        os.makedirs(self.path, exist_ok=True)
        net_save_path = os.path.join(self.path, 'net.config')
        json.dump(self.net.module.config, open(net_save_path, 'w'), indent=4)
        if print_info:
            print('Network configs dump to %s' % net_save_path)
        run_save_path = os.path.join(self.path, 'run.config')
        json.dump(self.run_config.config, open(run_save_path, 'w'), indent=4)
        if print_info:
            print('Run configs dump to %s' % run_save_path)

    """ train and test """

    def write_log(self, log_str, prefix, should_print=True):
        """ prefix: valid, train, test """
        if prefix in ['valid', 'test']:
            with open(os.path.join(self.logs_path, 'valid_console.txt'), 'a') as fout:
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['valid', 'test', 'train']:
            with open(os.path.join(self.logs_path, 'train_console.txt'), 'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['prune']:
            with open(os.path.join(self.logs_path, 'prune_console.txt'), 'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if should_print:
            print(log_str)

    def validate(self, is_test=True, net=None, use_train_mode=False, return_top5=False):
        if is_test:
            data_loader = self.run_config.test_loader
        else:
            data_loader = self.run_config.valid_loader
        if net is None:
            net = self.net
        if use_train_mode:
            net.train()
        else:
            net.eval()

        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        # noinspection PyUnresolvedReferences
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                images, labels = data[0].cuda(non_blocking=True), data[1].cuda(non_blocking=True)
                # compute output
                output = net(images)
                loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                reduced_loss = self.reduce_tensor(loss.data)
                acc1 = self.reduce_tensor(acc1)
                acc5 = self.reduce_tensor(acc5)
                losses.update(reduced_loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.run_config.print_frequency == 0 or i + 1 == len(data_loader):
                    prefix = 'Test' if is_test else 'Valid'
                    test_log = prefix + ': [{0}/{1}]\t' \
                                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
                        format(i, len(data_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
                    if return_top5:
                        test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
                    print(test_log)

        self.run_config.valid_loader.reset()
        self.run_config.test_loader.reset()
        if return_top5:
            return losses.avg, top1.avg, top5.avg
        else:
            return losses.avg, top1.avg

    def train_bn(self, epochs=1):
        if self.run_config.local_rank == 0:
            print('training bn')
        # reset BatchNorm running statistics before re-estimating them
        for m in self.net.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)
        self.net.train()
        for i in range(epochs):
            for _, data in enumerate(self.run_config.train_loader):
                images, labels = data[0].cuda(non_blocking=True), data[1].cuda(non_blocking=True)
                output = self.net(images)  # forward pass only, to update BN statistics
                del output, images, labels
        if self.run_config.local_rank == 0:
            print('training bn finished')

    def train_one_epoch(self, adjust_lr_func, train_log_func, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, data in enumerate(self.run_config.train_loader):
            data_time.update(time.time() - end)
            new_lr = adjust_lr_func(i)
            images, labels = data[0].cuda(non_blocking=True), data[1].cuda(non_blocking=True)

            # compute output
            output = self.net(images)
            if self.run_config.label_smoothing > 0:
                loss = cross_entropy_with_label_smoothing(output, labels, self.run_config.label_smoothing)
            else:
                loss = self.criterion(output, labels)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            reduced_loss = self.reduce_tensor(loss.data)
            acc1 = self.reduce_tensor(acc1)
            acc5 = self.reduce_tensor(acc5)
            losses.update(reduced_loss, images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # compute gradient and do SGD step
            self.net.zero_grad()  # or self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            torch.cuda.synchronize()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (i % self.run_config.print_frequency == 0 or i + 1 == len(self.run_config.train_loader)) \
                    and self.run_config.local_rank == 0:
                batch_log = train_log_func(i, batch_time, data_time, losses, top1, top5, new_lr)
                self.write_log(batch_log, 'train')
        return top1, top5

    def train(self, print_top5=False):
        def train_log_func(epoch_, i, batch_time, data_time, losses, top1, top5, lr):
            batch_log = 'Train [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                        'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
                format(epoch_ + 1, i, len(self.run_config.train_loader) - 1,
                       batch_time=batch_time, data_time=data_time, losses=losses, top1=top1)
            if print_top5:
                batch_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
            batch_log += '\tlr {lr:.5f}'.format(lr=lr)
            return batch_log

        for epoch in range(self.start_epoch, self.run_config.n_epochs):
            if self.run_config.local_rank == 0:
                print('\n', '-' * 30, 'Train epoch: %d' % (epoch + 1), '-' * 30, '\n')

            end = time.time()
            train_top1, train_top5 = self.train_one_epoch(
                lambda i: self.run_config.adjust_learning_rate(self.optimizer, epoch, i,
                                                               len(self.run_config.train_loader)),
                lambda i, batch_time, data_time, losses, top1, top5, new_lr:
                train_log_func(epoch, i, batch_time, data_time, losses, top1, top5, new_lr),
                epoch)
            time_per_epoch = time.time() - end
            seconds_left = int((self.run_config.n_epochs - epoch - 1) * time_per_epoch)
            if self.run_config.local_rank == 0:
                print('Time per epoch: %s, Est. complete in: %s' %
                      (str(timedelta(seconds=time_per_epoch)), str(timedelta(seconds=seconds_left))))

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                val_loss, val_acc, val_acc5 = self.validate(is_test=False, return_top5=True)
                is_best = val_acc > self.best_acc
                self.best_acc = max(self.best_acc, val_acc)
                val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f} ({4:.3f})'. \
                    format(epoch + 1, self.run_config.n_epochs, val_loss, val_acc, self.best_acc)
                if print_top5:
                    val_log += '\ttop-5 acc {0:.3f}\tTrain top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                        format(val_acc5, top1=train_top1, top5=train_top5)
                else:
                    val_log += '\tTrain top-1 {top1.avg:.3f}'.format(top1=train_top1)
                if self.run_config.local_rank == 0:
                    self.write_log(val_log, 'valid')
            else:
                is_best = False

            if self.run_config.local_rank == 0:
                self.save_model({
                    'epoch': epoch,
                    'best_acc': self.best_acc,
                    'optimizer': self.optimizer.state_dict(),
                    'state_dict': self.net.state_dict(),
                }, is_best=is_best)
            self.run_config.train_loader.reset()
            self.run_config.valid_loader.reset()
            self.run_config.test_loader.reset()

    def reduce_tensor(self, tensor):
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.run_config.world_size
        return rt
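# Hedged sketch of the accuracy(output, target, topk) helper used throughout
# RunManager (assumed; this follows the classic PyTorch ImageNet-example
# implementation, returning one tensor per requested k):
import torch


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res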
def main():
    # make save dir
    if args.local_rank == 0:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
    # launch the logger
    Log.init(log_level=args.log_level,
             log_file=osp.join(args.save_dir, args.log_file),
             log_format=args.log_format,
             rewrite=args.rewrite,
             stdout_level=args.stdout_level)

    # RGB or BGR input (RGB input for ImageNet pretrained models, BGR input for Caffe pretrained models)
    if args.rgb:
        IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
        IMG_VARS = np.array((0.229, 0.224, 0.225), dtype=np.float32)
    else:
        IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
        IMG_VARS = np.array((1, 1, 1), dtype=np.float32)

    # set models
    import libs.models as models
    deeplab = models.__dict__[args.arch](num_classes=args.num_classes, data_set=args.data_set)
    if args.restore_from is not None:
        saved_state_dict = torch.load(args.restore_from, map_location=torch.device('cpu'))
        new_params = deeplab.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
        Log.info("load pretrained models")
        if deeplab.backbone is not None:
            deeplab.backbone.load_state_dict(new_params, strict=False)
        else:
            deeplab.load_state_dict(new_params, strict=False)
    else:
        Log.info("train from scratch")

    args.world_size = 1
    if 'WORLD_SIZE' in os.environ and args.apex:
        args.apex = int(os.environ['WORLD_SIZE']) > 1
        args.world_size = int(os.environ['WORLD_SIZE'])
        print("Total world size: ", int(os.environ['WORLD_SIZE']))

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = args.input_size, args.input_size
    input_size = (h, w)

    # Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)
    Log.info("Local Rank: {}".format(args.local_rank))
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # set optimizer
    optimizer = optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate}],
        lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # set on cuda
    deeplab.cuda()

    # models transformation: convert BN to SyncBN before wrapping with DDP
    deeplab = apex.parallel.convert_syncbn_model(deeplab)
    model = DistributedDataParallel(deeplab)
    model.train()
    model.float()
    model.cuda()

    # set loss function
    if args.ohem:
        criterion = CriterionOhemDSN(thresh=args.ohem_thres, min_kept=args.ohem_keep)  # OHEM CrossEntropy
        if "ic" in args.arch:
            criterion = CriterionICNet(thresh=args.ohem_thres, min_kept=args.ohem_keep)
        if "dfa" in args.arch:
            criterion = CriterionDFANet(thresh=args.ohem_thres, min_kept=args.ohem_keep)
    else:
        criterion = CriterionDSN()  # CrossEntropy
    criterion.cuda()
    cudnn.benchmark = True

    if args.world_size == 1:
        print(model)

    # this is a little different from the multi-GPU setting in non-distributed training, because each
    # trainloader belongs to a process that samples from the dataset class
    batch_size = args.gpu_num * args.batch_size_per_gpu
    max_iters = args.num_steps * batch_size / args.gpu_num

    # set data loader
    data_set = Cityscapes(args.data_dir, args.data_list, max_iters=max_iters, crop_size=input_size,
                          scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN,
                          vars=IMG_VARS, RGB=args.rgb)
    trainloader = data.DataLoader(data_set, batch_size=args.batch_size_per_gpu, shuffle=True,
                                  num_workers=args.num_workers, pin_memory=True)
    print("trainloader", len(trainloader))

    torch.cuda.empty_cache()

    # start training
    for i_iter, batch in enumerate(trainloader):
        images, labels = batch
        images = images.cuda()
        labels = labels.long().cuda()
        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, args, i_iter, len(trainloader))
        preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        reduce_loss = all_reduce_tensor(loss, world_size=args.gpu_num)
        if args.local_rank == 0:
            Log.info('iter = {} of {} completed, lr={}, loss = {}'.format(
                i_iter, len(trainloader), lr, reduce_loss.data.cpu().numpy()))
            if i_iter % args.save_pred_every == 0 and i_iter > args.save_start:
                print('save models ...')
                torch.save(deeplab.state_dict(),
                           osp.join(args.save_dir, str(args.arch) + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    if args.local_rank == 0:
        Log.info("Training cost: " + str(end - start) + ' seconds')
        Log.info("Save final models")
        torch.save(deeplab.state_dict(), osp.join(args.save_dir, str(args.arch) + '_final' + '.pth'))
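# Hedged sketch of the adjust_learning_rate helper called in the training loop
# above (assumed, not shown in this excerpt). Segmentation codebases of this
# kind typically use a "poly" schedule: lr = base_lr * (1 - iter/max_iter)^power.
def adjust_learning_rate(optimizer, args, i_iter, max_iter, power=0.9):
    lr = args.learning_rate * ((1 - float(i_iter) / max_iter) ** power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr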
class Solver(object):

    def __init__(self):
        self.version = __version__
        self.distributed = False
        self.world_size = 1
        self.local_rank = 0
        self.epoch = 0
        self.iteration = 0
        self.config = None
        self.model, self.optimizer, self.lr_policy = None, None, None
        self.step_decay = 1
        self.filtered_keys = None

        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = self.world_size > 1 or torch.cuda.device_count() > 1
        if self.distributed:
            dist.init_process_group(backend="nccl", init_method='env://')
            self.local_rank = dist.get_rank()
            torch.cuda.set_device(self.local_rank)
            logging.info('[distributed mode] world size: {}, local rank: {}.'.format(
                self.world_size, self.local_rank))
        else:
            logging.info('[Single GPU mode]')

    def _build_environ(self):
        if self.config['environ']['deterministic']:
            cudnn.benchmark = False
            cudnn.deterministic = True
            torch.set_printoptions(precision=10)
        else:
            cudnn.benchmark = True
        if self.config['apex']:
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

        # set random seed
        torch.manual_seed(self.config['environ']['seed'])
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.config['environ']['seed'])
        np.random.seed(self.config['environ']['seed'])
        random.seed(self.config['environ']['seed'])

        # grad clip settings
        self.grad_clip_params = self.config["solver"]["optimizer"].get("grad_clip")
        self.use_grad_clip = self.grad_clip_params is not None
        if self.use_grad_clip:
            logging.info("Using grad clip and params is {}".format(self.grad_clip_params))
        else:
            logging.info("Not using grad clip.")

    def init_from_scratch(self, config):
        t_start = time.time()
        self.config = config
        self._build_environ()

        # model and optimizer
        self.model = _get_model(self.config)
        self.filtered_keys = [p.name for p in inspect.signature(self.model.forward).parameters.values()]
        # scale each parameter group's lr relative to the base lr
        model_params = []
        for params in self.model.optimizer_params():
            params["lr"] = self.config["solver"]["optimizer"]["params"]["lr"] * params["lr"]
            model_params.append(params)
        self.optimizer = _get_optimizer(config['solver']['optimizer'], model_params=model_params)
        self.lr_policy = _get_lr_policy(config['solver']['lr_policy'], optimizer=self.optimizer)
        self.step_decay = config['solver']['step_decay']

        if config['model'].get('pretrained_model') is not None:
            logging.info('loading pretrained model from {}.'.format(config['model']['pretrained_model']))
            load_model(self.model, config['model']['pretrained_model'], distributed=False)
        self.model.cuda(self.local_rank)
        if self.distributed:
            self.model = convert_syncbn_model(self.model)
        if self.config['apex']['amp_used']:
            # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
            # for convenient interoperation with argparse.
            logging.info("Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}.".format(
                self.config['apex']['opt_level'], self.config['apex']['keep_batchnorm_fp32'],
                self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        t_end = time.time()
        logging.info("Init trainer from scratch, Time usage: IO: {}".format(t_end - t_start))

    def init_from_checkpoint(self, continue_state_object):
        t_start = time.time()
        self.config = continue_state_object['config']
        self._build_environ()

        self.model = _get_model(self.config)
        self.filtered_keys = [p.name for p in inspect.signature(self.model.forward).parameters.values()]
        model_params = []
        for params in self.model.optimizer_params():
            params["lr"] = self.config["solver"]["optimizer"]["params"]["lr"] * params["lr"]
            model_params.append(params)
        self.optimizer = _get_optimizer(self.config['solver']['optimizer'], model_params=model_params)
        self.lr_policy = _get_lr_policy(self.config['solver']['lr_policy'], optimizer=self.optimizer)

        load_model(self.model, continue_state_object['model'], distributed=False)
        self.model.cuda(self.local_rank)
        if self.distributed:
            self.model = convert_syncbn_model(self.model)
        if self.config['apex']['amp_used']:
            logging.info("Initialize Amp. opt level={}, keep batchnorm fp32={}, loss_scale={}.".format(
                self.config['apex']['opt_level'], self.config['apex']['keep_batchnorm_fp32'],
                self.config['apex']['loss_scale']))
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer,
                opt_level=self.config['apex']['opt_level'],
                keep_batchnorm_fp32=self.config['apex']["keep_batchnorm_fp32"],
                loss_scale=self.config['apex']["loss_scale"])
            amp.load_state_dict(continue_state_object['amp'])
        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        self.optimizer.load_state_dict(continue_state_object['optimizer'])
        self.lr_policy.load_state_dict(continue_state_object['lr_policy'])
        self.step_decay = self.config['solver']['step_decay']
        self.epoch = continue_state_object['epoch']
        self.iteration = continue_state_object["iteration"]
        del continue_state_object

        t_end = time.time()
        logging.info("Init trainer from checkpoint, Time usage: IO: {}".format(t_end - t_start))

    def parse_kwargs(self, minibatch):
        kwargs = {k: v for k, v in minibatch.items() if k in self.filtered_keys}
        if torch.cuda.is_available():
            kwargs = tensor2cuda(kwargs)
        return kwargs

    def step(self, **kwargs):
        self.iteration += 1
        loss, loss_dorn, loss_c3d = self.model(**kwargs)
        loss_dorn /= self.step_decay
        loss_c3d /= self.step_decay
        loss /= self.step_decay

        # backward
        if self.distributed and self.config['apex']['amp_used']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # accumulate gradients for step_decay iterations before stepping
        if self.iteration % self.step_decay == 0:
            if self.use_grad_clip:
                clip_grad_norm_(self.model.parameters(), **self.grad_clip_params)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.lr_policy.step(self.epoch)

        if self.distributed:
            reduced_loss = reduce_tensor(loss.data, self.world_size)
            reduced_loss_dorn = reduce_tensor(loss_dorn.data, self.world_size)
            reduced_loss_c3d = reduce_tensor(loss_c3d.data, self.world_size)
        else:
            reduced_loss = loss.data
            reduced_loss_dorn = loss_dorn.data
            reduced_loss_c3d = loss_c3d.data
        return reduced_loss, reduced_loss_dorn, reduced_loss_c3d

    def step_no_grad(self, **kwargs):
        with torch.no_grad():
            out = self.model(**kwargs)
        return out

    def before_epoch(self, epoch):
        synchronize()
        self.iteration = 0
        self.epoch = epoch
        self.model.train()
        torch.cuda.empty_cache()

    def after_epoch(self, epoch=None):
        synchronize()
        self.model.eval()
        torch.cuda.empty_cache()

    def save_checkpoint(self, path):
        if self.local_rank == 0:
            t_start = time.time()
            state_dict = {}
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            # strip the 'module.' prefix added by DistributedDataParallel
            for k, v in self.model.state_dict().items():
                key = k
                if k.split('.')[0] == 'module':
                    key = k[7:]
                new_state_dict[key] = v
            if self.config['apex']['amp_used']:
                state_dict['amp'] = amp.state_dict()
            state_dict['config'] = self.config
            state_dict['model'] = new_state_dict
            state_dict['optimizer'] = self.optimizer.state_dict()
            state_dict['lr_policy'] = self.lr_policy.state_dict()
            state_dict['epoch'] = self.epoch
            state_dict['iteration'] = self.iteration

            t_iobegin = time.time()
            torch.save(state_dict, path)
            del state_dict
            del new_state_dict
            t_end = time.time()
            logging.info("Save checkpoint to file {}, "
                         "Time usage:\n\tprepare snapshot: {}, IO: {}".format(
                             path, t_iobegin - t_start, t_end - t_iobegin))

    def get_learning_rates(self):
        lrs = []
        for i in range(len(self.optimizer.param_groups)):
            lrs.append(self.optimizer.param_groups[i]['lr'])
        return lrs
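# Hedged sketch of the tensor2cuda helper used by parse_kwargs above (assumed,
# not shown in this file): recursively move tensors nested in dicts/lists to
# the GPU, leaving everything else untouched.
import torch


def tensor2cuda(obj):
    if torch.is_tensor(obj):
        return obj.cuda(non_blocking=True)
    if isinstance(obj, dict):
        return {k: tensor2cuda(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(tensor2cuda(v) for v in obj)
    return obj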
class CdartsTrainer(object): def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None, regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True, epochs=64, steps_per_epoch=128, fake_batch=128, loss_alpha=2, loss_T=2, distributed=True, log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs', w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4, nasnet_lr=0.2, local_rank=0, share_module=True): """ Initialize a CdartsTrainer. Parameters ---------- model_small : nn.Module PyTorch model to be trained. This is the search network of CDARTS. model_large : nn.Module PyTorch model to be trained. This is the evaluation network of CDARTS. criterion : callable Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``. loaders : list of torch.utils.data.DataLoader List of train data and valid data loaders, for training weights and architecture weights respectively. samplers : list of torch.utils.data.Sampler List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed. In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details. logger : logging.Logger The logger for logging. Will use nni logger by default (if logger is ``None``). regular_coeff : float The coefficient of regular loss. regular_ratio : float The ratio of regular loss. warmup_epochs : int The epochs to warmup the search network fix_head : bool ``True`` if fixing the paramters of auxiliary heads, else unfix the paramters of auxiliary heads. epochs : int Number of epochs planned for training. steps_per_epoch : int Steps of one epoch. fake_batch : int Batch*fake_batch is used for memory saving. loss_alpha : float The loss coefficient. loss_T : float The loss coefficient. distributed : bool ``True`` if using distributed training, else non-distributed training. log_frequency : int Step count per logging. grad_clip : float Gradient clipping for weights. interactive_type : string ``kl`` or ``smoothl1``. output_path : string Log storage path. w_lr : float Learning rate of the search network parameters. w_momentum : float Momentum of the search and the evaluation network. w_weight_decay : float The weight decay the search and the evaluation network parameters. alpha_lr : float Learning rate of the architecture parameters. alpha_weight_decay : float The weight decay the architecture parameters. nasnet_lr : float Learning rate of the evaluation network parameters. local_rank : int The number of thread. share_module : bool ``True`` if sharing the stem and auxiliary heads, else not sharing these modules. 
""" if logger is None: logger = logging.getLogger(__name__) train_loader, valid_loader = loaders train_sampler, valid_sampler = samplers self.train_loader = CyclicIterator(train_loader, train_sampler, distributed) self.valid_loader = CyclicIterator(valid_loader, valid_sampler, distributed) self.regular_coeff = regular_coeff self.regular_ratio = regular_ratio self.warmup_epochs = warmup_epochs self.fix_head = fix_head self.epochs = epochs self.steps_per_epoch = steps_per_epoch if self.steps_per_epoch is None: self.steps_per_epoch = min(len(self.train_loader), len(self.valid_loader)) self.fake_batch = fake_batch self.loss_alpha = loss_alpha self.grad_clip = grad_clip if interactive_type == "kl": self.interactive_loss = InteractiveKLLoss(loss_T) elif interactive_type == "smoothl1": self.interactive_loss = nn.SmoothL1Loss() self.loss_T = loss_T self.distributed = distributed self.log_frequency = log_frequency self.main_proc = not distributed or local_rank == 0 self.logger = logger self.checkpoint_dir = output_path if self.main_proc: os.makedirs(self.checkpoint_dir, exist_ok=True) if distributed: torch.distributed.barrier() self.model_small = model_small self.model_large = model_large if self.fix_head: for param in self.model_small.aux_head.parameters(): param.requires_grad = False for param in self.model_large.aux_head.parameters(): param.requires_grad = False self.mutator_small = RegularizedDartsMutator(self.model_small).cuda() self.mutator_large = DartsDiscreteMutator(self.model_large, self.mutator_small).cuda() self.criterion = criterion self.optimizer_small =apex.optimizers.FusedSGD(self.model_small.parameters(), w_lr, momentum=w_momentum, weight_decay=w_weight_decay) self.optimizer_large = apex.optimizers.FusedSGD(self.model_large.parameters(), nasnet_lr, momentum=w_momentum, weight_decay=w_weight_decay) self.optimizer_alpha = apex.optimizers.FusedAdam(self.mutator_small.parameters(), alpha_lr ) if distributed: apex.parallel.convert_syncbn_model(self.model_small) apex.parallel.convert_syncbn_model(self.model_large) self.model_small = DistributedDataParallel(self.model_small, delay_allreduce=True) self.model_large = DistributedDataParallel(self.model_large, delay_allreduce=True) self.mutator_small = RegularizedMutatorParallel(self.mutator_small, delay_allreduce=True) if share_module: self.model_small.callback_queued = True self.model_large.callback_queued = True # mutator large never gets optimized, so do not need parallelized def _warmup(self, phase, epoch): assert phase in [PHASE_SMALL, PHASE_LARGE] if phase == PHASE_SMALL: model, optimizer = self.model_small, self.optimizer_small elif phase == PHASE_LARGE: model, optimizer = self.model_large, self.optimizer_large model.train() meters = AverageMeterGroup() for step in range(self.steps_per_epoch): optimizer.zero_grad() totall_l =0 totall_p =0 for fb in range(self.fake_batch): x, y = next(self.train_loader) x, y = x.cuda(), y.cuda() logits_main, _ = model(x) loss = self.criterion(logits_main, y)/self.fake_batch loss.backward() totall_l += loss prec1,prec1 = accuracy(logits_main, y, topk=(1,1)) prec1 = prec1/self.fake_batch totall_p += prec1 self._clip_grad_norm(model) optimizer.step() metrics = {"prec1": totall_p, "loss": totall_l} metrics = reduce_metrics(metrics, self.distributed) meters.update(metrics) if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch): self.logger.info("Epoch [%d/%d] Step [%d/%d] (%s) %s", epoch + 1, self.epochs, step + 1, self.steps_per_epoch, phase, meters) def 
    def _clip_grad_norm(self, model):
        if isinstance(model, DistributedDataParallel):
            nn.utils.clip_grad_norm_(model.module.parameters(), self.grad_clip)
        else:
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)

    def _reset_nan(self, parameters):
        with torch.no_grad():
            for param in parameters:
                for i, p in enumerate(param):
                    if p != p:  # equivalent to `isnan(p)`
                        param[i] = float("-inf")

    def _joint_train(self, epoch):
        meters = AverageMeterGroup()
        for step in range(self.steps_per_epoch):
            total_lc = 0
            total_lw = 0
            total_li = 0
            total_lr = 0
            loss_regular = self.mutator_small.reset_with_loss()
            reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) /
                            ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
            if loss_regular:
                loss_regular *= reg_decay
            samples_x = []
            samples_y = []
            criterion_l = []
            ensemble_logits_l = []

            def trn_l(total_lc, total_lw, total_li, total_lr):
                self.model_large.train()
                self.optimizer_large.zero_grad()
                for fb in range(self.fake_batch):
                    val_x, val_y = next(self.valid_loader)
                    val_x, val_y = val_x.cuda(), val_y.cuda()
                    logits_main, ensemble_logits_main = self.model_large(val_x)
                    cel = self.criterion(logits_main, val_y)
                    loss_weight = cel / self.fake_batch
                    loss_weight.backward(retain_graph=True)
                    criterion_l.append(cel.cpu())
                    ensemble_logits_l.append(ensemble_logits_main.cpu())
                    total_lw += float(loss_weight)
                    samples_x.append(val_x.cpu())
                    samples_y.append(val_y.cpu())
                self._clip_grad_norm(self.model_large)
                self.optimizer_large.step()
                self.model_large.train(mode=False)
                return total_lc, total_lw, total_li, total_lr

            total_lc, total_lw, total_li, total_lr = trn_l(total_lc, total_lw, total_li, total_lr)

            def sleep(s):
                # debugging aid: pause and dump the CUDA memory summary
                print("--" + str(s))
                time.sleep(2)
                print(torch.cuda.memory_summary())
                print("++" + str(s))

            def trn_s(total_lc, total_lw, total_li, total_lr):
                print("sts")
                self.model_small.cuda()
                self.model_small.train()
                self.optimizer_alpha.zero_grad()
                self.optimizer_small.zero_grad()
                i = 0
                ls = []
                els = []
                sleep(0)

                def sc():
                    reg_decay = max(self.regular_coeff * (1 - float(epoch - self.warmup_epochs) /
                                    ((self.epochs - self.warmup_epochs) * self.regular_ratio)), 0)
                    loss_regular = self.mutator_small.reset_with_loss()
                    if loss_regular:
                        loss_regular *= reg_decay
                        loss_regular.backward()
                        loss_regular = loss_regular.cpu().detach()

                sc()
                sleep(0.5)
                for i in range(len(samples_x)):
                    val_x = samples_x[i].cuda()
                    val_y = samples_y[i].cuda()
                    logits_search, ensemble_logits_search = self.model_small(val_x)
                    cls = self.criterion(logits_search, val_y)
                    ls.append(cls.cpu())
                    els.append(ensemble_logits_search.cpu())
                    val_x.cpu().detach()
                    val_y.cpu().detach()
                sleep(1)
                for i in range(len(samples_x)):
                    criterion_logits_main = criterion_l[i].cuda()
                    cls = ls[i].cuda()
                    ensemble_logits_search = els[i].cuda()
                    loss_weight = cls / self.fake_batch
                    total_lw += float(loss_weight)
                    loss_cls = (cls + criterion_logits_main) / self.loss_alpha / self.fake_batch
                    loss_cls.backward(retain_graph=True)
                    total_lc += float(loss_cls)
                    criterion_logits_main.cpu().detach()
                sleep(2)
                for i in range(len(samples_x)):
                    ensemble_logits_main = ensemble_logits_l[i].cuda()
                    ensemble_logits_search = els[i].cuda()
                    sleep(3)
                    loss_interactive = self.interactive_loss(ensemble_logits_search, ensemble_logits_main) * \
                        (self.loss_T ** 2) * self.loss_alpha / self.fake_batch
                    loss_interactive.backward(retain_graph=True)
                    sleep(5)
                    ensemble_logits_search.cpu()
                    total_li += float(loss_interactive)
                    total_lr += float(loss_regular)
                    ensemble_logits_search.cpu().detach()
                    ensemble_logits_main.cpu().detach()
                    sleep(6)
                    i = i + 1
                self.optimizer_alpha.step()
                self._clip_grad_norm(self.model_small)
                self.optimizer_small.step()
                self.model_small.train(mode=False)
                samples_x.clear()
                samples_y.clear()
                criterion_l.clear()
                ensemble_logits_l.clear()
                return total_lc, total_lw, total_li, total_lr

            total_lc, total_lw, total_li, total_lr = trn_s(total_lc, total_lw, total_li, total_lr)
            metrics = {"loss_cls": total_lc, "loss_interactive": total_li,
                       "loss_regular": total_lr, "loss_weight": total_lw}
            # metrics = reduce_metrics(metrics, self.distributed)
            meters.update(metrics)
            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                self.logger.info("Epoch [%d/%d] Step [%d/%d] (joint) %s", epoch + 1, self.epochs,
                                 step + 1, self.steps_per_epoch, meters)

    def train(self):
        for epoch in range(self.epochs):
            if epoch < self.warmup_epochs:
                with torch.no_grad():  # otherwise grads will be retained on the architecture params
                    self.mutator_small.reset_with_loss()
                self._warmup(PHASE_SMALL, epoch)
            else:
                with torch.no_grad():
                    self.mutator_large.reset()
                self._warmup(PHASE_LARGE, epoch)
                self._joint_train(epoch)
            self.export(os.path.join(self.checkpoint_dir, "epoch_{:02d}.json".format(epoch)),
                        os.path.join(self.checkpoint_dir, "epoch_{:02d}.genotypes".format(epoch)))

    def export(self, file, genotype_file):
        if self.main_proc:
            mutator_export, genotypes = self.mutator_small.export(self.logger)
            with open(file, "w") as f:
                json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
            with open(genotype_file, "w") as f:
                f.write(str(genotypes))
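# NOTE (added example): the `fake_batch` loops above are ordinary gradient accumulation;
# each micro-batch loss is divided by the number of micro-batches so the summed gradient
# matches that of one large batch. A minimal, self-contained sketch of the same pattern,
# assuming generic `model`, `criterion`, `optimizer` and an iterable `loader` (all names
# here are illustrative, not part of the CDARTS code):

import torch


def accumulate_step(model, criterion, optimizer, loader, accum_steps, grad_clip=5.0):
    """Run one optimizer step whose gradient is accumulated over `accum_steps` micro-batches."""
    optimizer.zero_grad()
    total_loss = 0.0
    for _ in range(accum_steps):
        x, y = next(loader)
        loss = criterion(model(x), y) / accum_steps  # scale so accumulated gradients average out
        loss.backward()                              # gradients sum across micro-batches
        total_loss += float(loss)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()
    return total_loss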
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--features_h5path",
        default="/coc/pskynet2/jlu347/multi-modal-bert/data/flick30k/flickr30k.h5",
    )
    # Required parameters
    parser.add_argument(
        "--val_file",
        default="data/flick30k/all_data_final_test_set0_2014.jsonline",
        type=str,
        help="The input train corpus.",
    )
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--pretrained_weight",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="result",
        type=str,
        # required=True,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_config.json",
        type=str,
        # required=True,
        help="The config file which specifies the model details.",
    )
    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=30,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )
    parser.add_argument(
        "--train_batch_size", default=128, type=int, help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--num_train_epochs", default=50, type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank", type=int, default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--from_pretrained",
        action="store_true",
        help="Whether to load weights from a pretrained model.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Whether to use the baseline model (single bert).",
    )
    parser.add_argument(
        "--zero_shot", action="store_true", help="Whether to evaluate directly (zero-shot)."
    )
    args = parser.parse_args()
    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from multimodal_bert.bert import MultiModalBertForImageCaptionRetrieval
        from multimodal_bert.bert import BertForMultiModalPreTraining
    else:
        from multimodal_bert.multi_modal_bert import (
            MultiModalBertForImageCaptionRetrieval,
            BertConfig,
        )
        from multimodal_bert.multi_modal_bert import BertForMultiModalPreTraining

    print(args)
    if args.save_name != "":
        timeStamp = args.save_name
    else:
        timeStamp = strftime("%d-%b-%y-%X-%a", gmtime())
        timeStamp += "_{:0>6d}".format(random.randint(0, 10 ** 7))

    savePath = os.path.join(args.output_dir, timeStamp)
    if not os.path.exists(savePath):
        os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    # save all the hidden parameters.
    with open(os.path.join(savePath, "command.txt"), "w") as f:
        print(args, file=f)  # Python 3.x
        print("\n", file=f)
        print(config, file=f)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps
            )
        )

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # train_examples = None
    num_train_optimization_steps = None

    print("Loading Train Dataset", args.val_file)
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model, do_lower_case=args.do_lower_case
    )
    image_features_reader = ImageFeaturesH5Reader(args.features_h5path, True)
    eval_dset = COCORetreivalDatasetVal(args.val_file, image_features_reader, tokenizer)

    config.fast_mode = True
    if args.from_pretrained:
        if args.zero_shot:
            model = BertForMultiModalPreTraining.from_pretrained(
                args.pretrained_weight, config
            )
        else:
            model = MultiModalBertForImageCaptionRetrieval.from_pretrained(
                args.pretrained_weight, config, dropout_prob=0.1
            )
    else:
        if args.zero_shot:
            model = BertForMultiModalPreTraining.from_pretrained(
                args.bert_model, config
            )
        else:
            model = MultiModalBertForImageCaptionRetrieval.from_pretrained(
                args.bert_model, config, dropout_prob=0.1
            )

    if args.fp16:
        model.half()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
    model.cuda()

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(eval_dset))
    logger.info("  Batch size = %d", args.train_batch_size)
    eval_dataloader = DataLoader(
        eval_dset,
        shuffle=False,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=False,
    )

    startIterID = 0
    global_step = 0
    masked_loss_v_tmp = 0
    masked_loss_t_tmp = 0
    next_sentence_loss_tmp = 0
    loss_tmp = 0

    r1, r5, r10, medr, meanr = evaluate(args, model, eval_dataloader)
    print("finish evaluation, save result to %s" % savePath)

    val_name = args.val_file.split("/")[-1]
    with open(os.path.join(savePath, val_name + "_result.txt"), "w") as f:
        print(
            "r1:%.3f, r5:%.3f, r10:%.3f, medr:%.3f, meanr:%.3f"
            % (r1, r5, r10, medr, meanr),
            file=f,
        )
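# NOTE (added example): `evaluate` is defined elsewhere in this repository. The recall
# metrics it returns (r1/r5/r10, median and mean rank) are conventionally computed from
# a query-by-candidate score matrix as below. This is a hedged sketch of the standard
# computation, not the repository's actual implementation:

import numpy as np


def ranking_metrics(scores, gt_index):
    """scores: (N, M) similarity matrix; gt_index[i] is the ground-truth column of row i."""
    ranks = np.empty(len(scores))
    for i, row in enumerate(scores):
        order = np.argsort(-row)                      # candidate indices, best score first
        ranks[i] = np.where(order == gt_index[i])[0][0]
    r1 = 100.0 * np.mean(ranks < 1)
    r5 = 100.0 * np.mean(ranks < 5)
    r10 = 100.0 * np.mean(ranks < 10)
    medr = np.median(ranks) + 1                       # 1-based median rank
    meanr = np.mean(ranks) + 1
    return r1, r5, r10, medr, meanr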
def main(opts): distributed.init_process_group(backend='nccl', init_method='env://') device_id, device = opts.local_rank, torch.device(opts.local_rank) rank, world_size = distributed.get_rank(), distributed.get_world_size() torch.cuda.set_device(device_id) # Initialize logging task_name = f"{opts.task}-{opts.dataset}" logdir_full = f"{opts.logdir}/{task_name}/{opts.name}/" if rank == 0: logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=opts.visualize, step=opts.step) else: logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=False) logger.print(f"Device: {device}") # Set up random seed torch.manual_seed(opts.random_seed) torch.cuda.manual_seed(opts.random_seed) np.random.seed(opts.random_seed) random.seed(opts.random_seed) # xxx Set up dataloader train_dst, val_dst, test_dst, n_classes = get_dataset(opts) # reset the seed, this revert changes in random seed random.seed(opts.random_seed) train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, sampler=DistributedSampler(train_dst, num_replicas=world_size, rank=rank), num_workers=opts.num_workers, drop_last=True) val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size if opts.crop_val else 1, sampler=DistributedSampler(val_dst, num_replicas=world_size, rank=rank), num_workers=opts.num_workers) logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, Val set: {len(val_dst)}," f" Test set: {len(test_dst)}, n_classes {n_classes}") logger.info(f"Total batch size is {opts.batch_size * world_size}") # xxx Set up model logger.info(f"Backbone: {opts.backbone}") step_checkpoint = None model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step)) logger.info(f"[!] Model made with{'out' if opts.no_pretrained else ''} pre-trained") if opts.step == 0: # if step 0, we don't need to instance the model_old model_old = None else: # instance model_old model_old = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step - 1)) if opts.fix_bn: model.fix_bn() logger.debug(model) # xxx Set up optimizer params = [] if not opts.freeze: params.append({"params": filter(lambda p: p.requires_grad, model.body.parameters()), 'weight_decay': opts.weight_decay}) params.append({"params": filter(lambda p: p.requires_grad, model.head.parameters()), 'weight_decay': opts.weight_decay}) params.append({"params": filter(lambda p: p.requires_grad, model.cls.parameters()), 'weight_decay': opts.weight_decay}) optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=True) if opts.lr_policy == 'poly': scheduler = utils.PolyLR(optimizer, max_iters=opts.epochs * len(train_loader), power=opts.lr_power) elif opts.lr_policy == 'step': scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor) else: raise NotImplementedError logger.debug("Optimizer:\n%s" % optimizer) if model_old is not None: [model, model_old], optimizer = amp.initialize([model.to(device), model_old.to(device)], optimizer, opt_level=opts.opt_level) model_old = DistributedDataParallel(model_old) else: model, optimizer = amp.initialize(model.to(device), optimizer, opt_level=opts.opt_level) # Put the model on GPU model = DistributedDataParallel(model, delay_allreduce=True) # xxx Load old model from old weights if step > 0! 
    if opts.step > 0:
        # get model path
        if opts.step_ckpt is not None:
            path = opts.step_ckpt
        else:
            path = f"checkpoints/step/{task_name}_{opts.name}_{opts.step - 1}.pth"

        # generate model from path
        if os.path.exists(path):
            step_checkpoint = torch.load(path, map_location="cpu")
            model.load_state_dict(step_checkpoint['model_state'], strict=False)  # False because of incr. classifiers
            if opts.init_balanced:
                # implement the balanced initialization (new cls has weight of background and bias = bias_bkg - log(N+1))
                model.module.init_new_classifier(device)
            # Load state dict from the model state dict, that contains the old model parameters
            model_old.load_state_dict(step_checkpoint['model_state'], strict=True)  # Load also here old parameters
            logger.info(f"[!] Previous model loaded from {path}")
            # clean memory
            del step_checkpoint['model_state']
        elif opts.debug:
            logger.info(f"[!] WARNING: Unable to find the checkpoint of step {opts.step - 1}! "
                        f"Do you really want to start from scratch?")
        else:
            raise FileNotFoundError(path)
        # put the old model into distributed memory and freeze it
        for par in model_old.parameters():
            par.requires_grad = False
        model_old.eval()

    # xxx Set up Trainer
    trainer_state = None
    # if not first step, then instance trainer from step_checkpoint
    if opts.step > 0 and step_checkpoint is not None:
        if 'trainer_state' in step_checkpoint:
            trainer_state = step_checkpoint['trainer_state']
    # instance trainer (model must already have the previous step weights)
    trainer = Trainer(model, model_old, device=device, opts=opts, trainer_state=trainer_state,
                      classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))

    # xxx Handle checkpoint for current model (model old will always be as previous step or None)
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"], strict=True)
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint['best_score']
        logger.info("[!] Model restored from %s" % opts.ckpt)
        # if we want to resume training, resume trainer from checkpoint
        if 'trainer_state' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state'])
        del checkpoint
    else:
        if opts.step == 0:
            logger.info("[!] Train from scratch")

    # xxx Train procedure
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(val_loader), opts.sample_num, replace=False)  # sample idxs for visualization
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # de-normalization for original images

    TRAIN = not opts.test
    val_metrics = StreamSegMetrics(n_classes)
    results = {}

    # check if random is equal here.
logger.print(torch.randint(0,100, (1,1))) # train/val here while cur_epoch < opts.epochs and TRAIN: # ===== Train ===== model.train() epoch_loss = trainer.train(cur_epoch=cur_epoch, optim=optimizer, train_loader=train_loader, scheduler=scheduler, logger=logger) logger.info(f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0]+epoch_loss[1]}," f" Class Loss={epoch_loss[0]}, Reg Loss={epoch_loss[1]}") # ===== Log metrics on Tensorboard ===== logger.add_scalar("E-Loss", epoch_loss[0]+epoch_loss[1], cur_epoch) logger.add_scalar("E-Loss-reg", epoch_loss[1], cur_epoch) logger.add_scalar("E-Loss-cls", epoch_loss[0], cur_epoch) # ===== Validation ===== if (cur_epoch + 1) % opts.val_interval == 0: logger.info("validate on val set...") model.eval() val_loss, val_score, ret_samples = trainer.validate(loader=val_loader, metrics=val_metrics, ret_samples_ids=sample_ids, logger=logger) logger.print("Done validation") logger.info(f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss[0]+val_loss[1]}," f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}") logger.info(val_metrics.to_str(val_score)) # ===== Save Best Model ===== if rank == 0: # save best model at the last iteration score = val_score['Mean IoU'] # best model to build incremental steps save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth", model, trainer, optimizer, scheduler, cur_epoch, score) logger.info("[!] Checkpoint saved.") # ===== Log metrics on Tensorboard ===== # visualize validation score and samples logger.add_scalar("V-Loss", val_loss[0]+val_loss[1], cur_epoch) logger.add_scalar("V-Loss-reg", val_loss[1], cur_epoch) logger.add_scalar("V-Loss-cls", val_loss[0], cur_epoch) logger.add_scalar("Val_Overall_Acc", val_score['Overall Acc'], cur_epoch) logger.add_scalar("Val_MeanIoU", val_score['Mean IoU'], cur_epoch) logger.add_table("Val_Class_IoU", val_score['Class IoU'], cur_epoch) logger.add_table("Val_Acc_IoU", val_score['Class Acc'], cur_epoch) # logger.add_figure("Val_Confusion_Matrix", val_score['Confusion Matrix'], cur_epoch) # keep the metric to print them at the end of training results["V-IoU"] = val_score['Class IoU'] results["V-Acc"] = val_score['Class Acc'] for k, (img, target, lbl) in enumerate(ret_samples): img = (denorm(img) * 255).astype(np.uint8) target = label2color(target).transpose(2, 0, 1).astype(np.uint8) lbl = label2color(lbl).transpose(2, 0, 1).astype(np.uint8) concat_img = np.concatenate((img, target, lbl), axis=2) # concat along width logger.add_image(f'Sample_{k}', concat_img, cur_epoch) cur_epoch += 1 # ===== Save Best Model at the end of training ===== if rank == 0 and TRAIN: # save best model at the last iteration # best model to build incremental steps save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth", model, trainer, optimizer, scheduler, cur_epoch, best_score) logger.info("[!] 
Checkpoint saved.") torch.distributed.barrier() # xxx From here starts the test code logger.info("*** Test the model on all seen classes...") # make data loader test_loader = data.DataLoader(test_dst, batch_size=opts.batch_size if opts.crop_val else 1, sampler=DistributedSampler(test_dst, num_replicas=world_size, rank=rank), num_workers=opts.num_workers) # load best model if TRAIN: model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step)) # Put the model on GPU model = DistributedDataParallel(model.cuda(device)) ckpt = f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth" checkpoint = torch.load(ckpt, map_location="cpu") model.load_state_dict(checkpoint["model_state"]) logger.info(f"*** Model restored from {ckpt}") del checkpoint trainer = Trainer(model, None, device=device, opts=opts) model.eval() val_loss, val_score, _ = trainer.validate(loader=test_loader, metrics=val_metrics, logger=logger) logger.print("Done test") logger.info(f"*** End of Test, Total Loss={val_loss[0]+val_loss[1]}," f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}") logger.info(val_metrics.to_str(val_score)) logger.add_table("Test_Class_IoU", val_score['Class IoU']) logger.add_table("Test_Class_Acc", val_score['Class Acc']) logger.add_figure("Test_Confusion_Matrix", val_score['Confusion Matrix']) results["T-IoU"] = val_score['Class IoU'] results["T-Acc"] = val_score['Class Acc'] logger.add_results(results) logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'], opts.step) logger.add_scalar("T_MeanIoU", val_score['Mean IoU'], opts.step) logger.add_scalar("T_MeanAcc", val_score['Mean Acc'], opts.step) logger.close()
class NetworkFactory(object): def __init__(self, system_config, model, distributed=False, gpu=None): super(NetworkFactory, self).__init__() self.system_config = system_config self.gpu = gpu self.model = DummyModule(model) self.loss = model.loss self.network = Network(self.model, self.loss) if distributed: from apex.parallel import DistributedDataParallel, convert_syncbn_model torch.cuda.set_device(gpu) self.network = self.network.cuda(gpu) self.network = convert_syncbn_model(self.network) self.network = DistributedDataParallel(self.network) else: # self.network = DataParallel(self.network, chunk_sizes=system_config.chunk_sizes) pass total_params = 0 for params in self.model.parameters(): num_params = 1 for x in params.size(): num_params *= x total_params += num_params print("total parameters: {}".format(total_params)) if system_config.opt_algo == "adam": self.optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, self.model.parameters())) elif system_config.opt_algo == "sgd": self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=system_config.learning_rate, momentum=0.9, weight_decay=0.0001) else: raise ValueError("unknown optimizer") def cuda(self): self.model.cuda() def train_mode(self): self.network.train() def eval_mode(self): self.network.eval() def _t_cuda(self, xs): if type(xs) is list: return [x.cuda(self.gpu, non_blocking=True) for x in xs] return xs.cuda(self.gpu, non_blocking=True) def train(self, xs, ys, **kwargs): xs = [self._t_cuda(x) for x in xs] ys = [self._t_cuda(y) for y in ys] self.optimizer.zero_grad() loss = self.network(xs, ys) loss = loss.mean() loss.backward() self.optimizer.step() return loss def validate(self, xs, ys, **kwargs): with torch.no_grad(): xs = [self._t_cuda(x) for x in xs] ys = [self._t_cuda(y) for y in ys] loss = self.network(xs, ys) loss = loss.mean() return loss def test(self, xs, **kwargs): with torch.no_grad(): xs = [self._t_cuda(x) for x in xs] return self.model(*xs, **kwargs) def set_lr(self, lr): print("setting learning rate to: {}".format(lr)) for param_group in self.optimizer.param_groups: param_group["lr"] = lr def load_pretrained_params(self, pretrained_model): print("loading from {}".format(pretrained_model)) with open(pretrained_model, "rb") as f: params = torch.load(f) self.model.load_state_dict(params) def load_params(self, iteration): cache_file = self.system_config.snapshot_file.format(iteration) print("loading model from {}".format(cache_file)) with open(cache_file, "rb") as f: params = torch.load(f) self.model.load_state_dict(params) def save_params(self, iteration): cache_file = self.system_config.snapshot_file.format(iteration) print("saving model to {}".format(cache_file)) with open(cache_file, "wb") as f: params = self.model.state_dict() torch.save(params, f)
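# NOTE (added example): the nested loop that counts parameters in NetworkFactory.__init__
# can be written more idiomatically with `Tensor.numel`; this one-liner is equivalent:

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())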
class TRans2InfoMax(Trans2Net): def __init__(self, cfg, writer=None): super(TRans2InfoMax, self).__init__(cfg, writer) def _define_networks(self): self.net = networks.Source_Model(self.cfg, device=self.device) self.cross_encoder = networks.Cross_Model(self.cfg, device=self.device) self.d_distribute = networks.GANDiscriminator(self.cfg, device=self.device) self.model_names = ['net', 'cross_encoder', 'd_distribute'] networks.print_network(self.net) networks.print_network(self.cross_encoder) networks.print_network(self.d_distribute) if 'PIX2PIX' in self.cfg.LOSS_TYPES: criterion_pix2pix = torch.nn.L1Loss() self.cross_encoder.set_pix2pix_criterion(criterion_pix2pix) def set_device(self): if not self.cfg.MULTIPROCESSING_DISTRIBUTED: self.net = nn.DataParallel(self.net).to(self.device) self.cross_encoder = nn.DataParallel(self.cross_encoder).to( self.device) self.d_distribute = nn.DataParallel(self.d_distribute).to( self.device) def set_optimizer(self, cfg): self.optimizers = [] # if self.cfg.RESUME: # self.params_list = [] # self.modules_ft = [self.net.layer0, self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4] # self.modules_sc = [self.net.evaluator] # # # for module in self.modules_ft: # self.params_list.append(dict(params=module.parameters(), lr=cfg.LR)) # for module in self.modules_sc: # self.params_list.append(dict(params=module.parameters(), lr=cfg.LR * 10)) # self.optimizer_g = torch.optim.Adam(self.params_list, lr=cfg.LR, betas=(0.5, 0.999)) # else: self.optimizer_g = torch.optim.Adam(self.net.parameters(), lr=cfg.LR, betas=(0.5, 0.999)) self.optimizer_c = torch.optim.Adam(self.cross_encoder.parameters(), lr=cfg.LR, betas=(0.5, 0.999)) self.optimizer_d = torch.optim.SGD(self.d_distribute.parameters(), lr=cfg.LR, momentum=0.9, weight_decay=0.0005) if cfg.MULTIPROCESSING_DISTRIBUTED: if cfg.USE_APEX: self.net, self.optimizer_g = apex.amp.initialize( self.net.cuda(), self.optimizer_g, opt_level=cfg.opt_level) self.cross_encoder, self.optimizer_c = apex.amp.initialize( self.cross_encoder.cuda(), self.optimizer_c, opt_level=cfg.opt_level) self.d_distribute, self.optimizer_d = apex.amp.initialize( self.d_distribute.cuda(), self.optimizer_d, opt_level=cfg.opt_level) self.net = DDP(self.net) self.cross_encoder = DDP(self.cross_encoder) self.d_distribute = DDP(self.d_distribute) else: self.net = torch.nn.parallel.DistributedDataParallel( self.net.cuda(), device_ids=[cfg.gpu]) self.cross_encoder = torch.nn.parallel.DistributedDataParallel( self.cross_encoder.cuda(), device_ids=[cfg.gpu]) self.d_distribute = torch.nn.parallel.DistributedDataParallel( self.d_distribute.cuda(), device_ids=[cfg.gpu]) self.optimizers.append(self.optimizer_d) self.optimizers.append(self.optimizer_g) self.optimizers.append(self.optimizer_c) # def get_patch(self, img): # # # Input of the function is a tensor [B, C, H, W] # # Output of the functions is a tensor [B * 49, C, 64, 64] # # patch_batch = None # all_patches_list = [] # # for y_patch in range(3): # for x_patch in range(3): # y1 = y_patch * 64 # y2 = y1 + 128 # # x1 = x_patch * 64 # x2 = x1 + 128 # # img_patches = img[:, :, y1:y2, x1:x2] # Batch(img_idx in batch), channels xrange, yrange # img_patches = img_patches.unsqueeze(dim=1) # all_patches_list.append(img_patches) # # # print(patch_batch.shape) # all_patches_tensor = torch.cat(all_patches_list, dim=1) # # patches_per_image = [] # for b in range(all_patches_tensor.shape[0]): # patches_per_image.append(all_patches_tensor[b]) # # patch_batch = torch.cat(patches_per_image, dim=0) # return 
patch_batch # encoder-decoder branch def _forward(self, class_only=False): # if self.phase == 'train': # self.source_modal = self.get_patch(self.source_modal) # self.target_modal = self.get_patch(self.target_modal) # # self.source_modal.view(self.batch_size, 3, 3, -1) if self.label is not None: label = self.label else: label = None self.result_g = self.net(self.source_modal, target=self.target_modal, label=label, class_only=class_only) if self.phase == 'train' and not self.cfg.NO_TRANS: self.result_c = self.cross_encoder(self.result_g['gen_cross'], target=self.target_modal) def _optimize(self, iter): self._forward() if 'GAN' in self.cfg.LOSS_TYPES: self.set_requires_grad([self.cross_encoder, self.net], False) self.set_requires_grad(self.d_distribute, True) fake_d = torch.cat( (self.result_c['feat_gen'], self.result_c['feat_target']), 1) real_d = torch.cat( (self.result_c['feat_target'], self.result_c['feat_target']), 1) # fake_d = self.result_c['feat_gen'] # real_d = self.result_c['feat_target'] if self.cfg.MULTIPROCESSING_DISTRIBUTED: loss_d_fake = self.d_distribute(fake_d.detach(), False) loss_d_true = self.d_distribute(real_d.detach(), True) else: loss_d_fake = self.d_distribute(fake_d.detach(), False).mean() loss_d_true = self.d_distribute(real_d.detach(), True).mean() loss_d = (loss_d_fake + loss_d_true) * 0.5 self.loss_meters['TRAIN_GAN_D_LOSS'].update( loss_d.item(), self.batch_size) self.optimizer_d.zero_grad() if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED: with apex.amp.scale_loss(loss_d, self.optimizer_d) as scaled_loss: scaled_loss.backward() else: loss_d.backward() self.optimizer_d.step() # G loss_g = self._construct_loss(iter) self.set_requires_grad([self.cross_encoder, self.net], True) if self.d_distribute is not None: self.set_requires_grad(self.d_distribute, False) self.optimizer_c.zero_grad() self.optimizer_g.zero_grad() if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED: with apex.amp.scale_loss( loss_g, [self.optimizer_c, self.optimizer_g]) as scaled_loss: scaled_loss.backward() else: loss_g.backward() self.optimizer_c.step() self.optimizer_g.step() def _construct_loss(self, iter=None): loss_total = torch.zeros(1).to(self.device) # decay_coef = 1 decay_coef = (iter / self.cfg.NITER_TOTAL) # small to big if 'PIX2PIX' in self.cfg.LOSS_TYPES: if self.cfg.MULTIPROCESSING_DISTRIBUTED: local_loss = self.result_c[ 'pix2pix_loss'] * self.cfg.ALPHA_LOCAL loss_total += local_loss else: local_loss = self.result_c['pix2pix_loss'].mean( ) * self.cfg.ALPHA_LOCAL self.loss_meters['TRAIN_PIX2PIX_LOSS'].update( local_loss.item(), self.batch_size) if 'PRIOR' in self.cfg.LOSS_TYPES: if self.cfg.MULTIPROCESSING_DISTRIBUTED: prior_loss = self.result_g['prior_loss'] * self.cfg.ALPHA_PRIOR loss_total += prior_loss else: prior_loss = self.result_g['prior_loss'].mean( ) * self.cfg.ALPHA_PRIOR self.loss_meters['TRAIN_PRIOR_LOSS'].update( prior_loss.item(), self.batch_size) if 'CROSS' in self.cfg.LOSS_TYPES: if self.cfg.MULTIPROCESSING_DISTRIBUTED: cross_loss = self.result_c['cross_loss'] * self.cfg.ALPHA_CROSS # cross_loss_self = self.result_c['cross_loss_self'] * self.cfg.ALPHA_CROSS * 0.2 else: cross_loss = self.result_c['cross_loss'].mean( ) * self.cfg.ALPHA_CROSS # cross_loss_self = self.result_c['cross_loss_self'].mean() * self.cfg.ALPHA_CROSS * decay_coef loss_total += cross_loss # loss_total += cross_loss_self # loss_total += cross_loss self.loss_meters['TRAIN_CROSS_LOSS'].update( cross_loss.item(), self.batch_size) # 
# self.loss_meters['TRAIN_CROSS_LOSS_SELF'].update(cross_loss_self.item(), self.batch_size)

        if 'HOMO' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                homo_loss = self.result_g['homo_loss'] * self.cfg.ALPHA_CROSS
                loss_total += homo_loss
            else:
                homo_loss = self.result_g['homo_loss'].mean() * self.cfg.ALPHA_CROSS
            self.loss_meters['TRAIN_HOMO_LOSS'].update(homo_loss.item(), self.batch_size)

        if 'CLS' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                cls_loss = self.result_g['cls_loss']
            else:
                cls_loss = self.result_g['cls_loss'].mean()
            self.loss_meters['TRAIN_CLS_LOSS'].update(cls_loss.item(), self.batch_size)
            loss_total += cls_loss

        if 'GAN' in self.cfg.LOSS_TYPES:
            # real_g = self.result_c['feat_gen']
            real_g = torch.cat((self.result_c['feat_gen'], self.result_c['feat_target']), 1)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                loss_gan_g = self.d_distribute(real_g, True) * self.cfg.ALPHA_GAN
            else:
                loss_gan_g = self.d_distribute(real_g, True).mean() * self.cfg.ALPHA_GAN
            self.loss_meters['TRAIN_GAN_G_LOSS'].update(loss_gan_g.item(), self.batch_size)
            loss_total += loss_gan_g

        return loss_total

    def set_log_data(self, cfg):
        super().set_log_data(cfg)
        self.log_keys = [
            'TRAIN_CROSS_LOSS', 'TRAIN_CROSS_LOSS_SELF', 'TRAIN_HOMO_LOSS', 'TRAIN_LOCAL_LOSS',
            'TRAIN_PRIOR_LOSS', 'INTERSECTION_MLP', 'LABEL_MLP', 'INTERSECTION_LIN', 'LABEL_LIN',
            'VAL_CLS_ACC_MLP', 'VAL_CLS_MEAN_ACC_MLP', 'TRAIN_GAN_D_LOSS', 'TRAIN_GAN_G_LOSS',
            'TRAIN_CONTRASTIVE_LOSS'
        ]
        for item in self.log_keys:
            self.loss_meters[item] = AverageMeter()

    def evaluate(self):
        # evaluate model on val_loader
        self.net.eval()
        self.phase = 'test'
        intersection_meter_mlp = self.loss_meters['INTERSECTION_MLP']
        target_meter_mlp = self.loss_meters['LABEL_MLP']
        for i, data in enumerate(self.val_loader):
            self.set_input(data)
            with torch.no_grad():
                self._forward(class_only=True)
            pred = self.result_g['pred'].data.max(1)[1]
            # lgt_glb_mlp = lgt_glb_mlp
            # lgt_glb_lin = lgt_glb_lin.data.max(1)[1]
            # [lgt_glb_mlp, lgt_glb_lin] = self.result_g['pred']
            # lgt_glb_mlp = lgt_glb_mlp.data.max(1)[1]
            # lgt_glb_lin = lgt_glb_lin.data.max(1)[1]
            intersection_mlp, union_mlp, label_mlp = util.intersectionAndUnionGPU(
                pred, self.label, self.cfg.NUM_CLASSES)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                dist.all_reduce(intersection_mlp), dist.all_reduce(union_mlp), dist.all_reduce(label_mlp)
            intersection_mlp, union_mlp, label_mlp = \
                intersection_mlp.cpu().numpy(), union_mlp.cpu().numpy(), label_mlp.cpu().numpy()
            intersection_meter_mlp.update(intersection_mlp, self.batch_size)
            target_meter_mlp.update(label_mlp, self.batch_size)

        # Mean ACC
        allAcc_mlp = sum(intersection_meter_mlp.sum) / (sum(target_meter_mlp.sum) + 1e-10)
        accuracy_class_mlp = intersection_meter_mlp.sum / (target_meter_mlp.sum + 1e-10)
        mAcc_mlp = np.mean(accuracy_class_mlp)
        self.loss_meters['VAL_CLS_ACC_MLP'].update(allAcc_mlp)
        self.loss_meters['VAL_CLS_MEAN_ACC_MLP'].update(mAcc_mlp)

    def write_loss(self, phase, global_step):
        task = self.cfg.TASK_TYPE
        self.writer.add_image(task + '/rgb',
                              torchvision.utils.make_grid(self.source_modal[:6].clone().cpu().data, 3,
                                                          normalize=True),
                              global_step=global_step)
        if phase == 'train':
            if not self.cfg.NO_TRANS:
                for k, v in self.result_g.items():
                    if 'gen' in k:
                        # if isinstance(self.result_g[k], list):
                        #     for i, (gen, _depth) in enumerate(zip(self.result_g['gen'], self.target_modal)):
                        #         self.writer.add_image(task + '/' + k + str(self.cfg.FINE_SIZE[0] / pow(2, i)),
                        #                               torchvision.utils.make_grid(gen[:6].clone().cpu().data, 3,
                        #                                                           normalize=True),
                        #
global_step=global_step) # self.writer.add_image(task + '/target' + str(self.cfg.FINE_SIZE[0] / pow(2, i)), # torchvision.utils.make_grid(_depth[:6].clone().cpu().data, 3, # normalize=True), # global_step=global_step) # else: self.writer.add_image( task + '/' + k, torchvision.utils.make_grid( self.result_g[k][:6].clone().cpu().data, 3, normalize=True), global_step=global_step) self.writer.add_image( task + '/target', torchvision.utils.make_grid( self.target_modal[:6].clone().cpu().data, 3, normalize=True), global_step=global_step) # self.writer.add_image(task + '/target_neg', # torchvision.utils.make_grid(self.target_modal_neg[:6].clone().cpu().data, 3, # # normalize=True), global_step=global_step) self.writer.add_scalar(task + '/LR', self.optimizer_g.param_groups[0]['lr'], global_step=global_step) for k, v in self.loss_meters.items(): if 'LOSS' in k and v.avg > 0: self.writer.add_scalar(task + '/' + k, v.avg, global_step=global_step) elif phase == 'test': for k, v in self.loss_meters.items(): if ('MEAN' in k or 'ACC' in k) and v.val > 0: self.writer.add_scalar(task + '/' + k, v.val * 100.0, global_step=global_step)
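# NOTE (added example): `util.intersectionAndUnionGPU`, used in `evaluate` above, is the
# usual semantic-segmentation helper that histograms per-class intersections and unions
# on the GPU. A hedged sketch of the common implementation (ignore_index handling
# omitted for brevity); this is the standard pattern, not necessarily this repository's
# exact code:

import torch


def intersection_and_union_gpu(pred, target, num_classes):
    pred, target = pred.view(-1), target.view(-1)
    intersection = pred[pred == target]  # pixels predicted correctly, per class
    area_inter = torch.histc(intersection.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred = torch.histc(pred.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_target = torch.histc(target.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred + area_target - area_inter
    return area_inter, area_union, area_target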
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    # Data files for VQA task.
    parser.add_argument("--features_h5path", default="data/coco/test2015.h5")
    parser.add_argument(
        "--train_file",
        default="data/VQA/training",
        type=str,
        # required=True,
        help="The input train corpus.",
    )
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--pretrained_weight",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        # required=True,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_config.json",
        type=str,
        # required=True,
        help="The config file which specifies the model details.",
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=30,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )
    parser.add_argument("--use_location", action="store_true", help="whether to use location.")
    parser.add_argument(
        "--train_batch_size", default=128, type=int, help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--num_train_epochs", default=30, type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument("--no_cuda", action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank", type=int, default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers", type=int, default=20,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--from_pretrained",
        action="store_true",
        help="Whether to load weights from a pretrained model.",
    )
    parser.add_argument("--save_name", default="", type=str,
                        help="save name for training.")
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Whether to use the baseline model (single bert).",
    )
    parser.add_argument("--split", default="test", type=str,
                        help="Dataset split to evaluate on (minval, val, test, test-dev).")
    parser.add_argument(
        "--use_chunk",
        default=0,
        type=float,
        help="whether to use chunks for parallel training.",
    )

    args = parser.parse_args()
    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from multimodal_bert.bert import MultiModalBertForVQA
    else:
        from multimodal_bert.multi_modal_bert import MultiModalBertForVQA, BertConfig

    print(args)
    if args.save_name != "":
        timeStamp = args.save_name
    else:
        timeStamp = strftime("%d-%b-%y-%X-%a", gmtime())
        timeStamp += "_{:0>6d}".format(random.randint(0, 10 ** 7))

    savePath = os.path.join(args.output_dir, timeStamp)
    if not os.path.exists(savePath):
        os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    # save all the hidden parameters.
    with open(os.path.join(savePath, "command.txt"), "w") as f:
        print(args, file=f)  # Python 3.x
        print("\n", file=f)
        print(config, file=f)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
format(device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError( "Invalid gradient_accumulation_steps parameter: {}, should be >= 1" .format(args.gradient_accumulation_steps)) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) # train_examples = None num_train_optimization_steps = None tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) image_features_reader = ImageFeaturesH5Reader(args.features_h5path, True) if args.split == "minval": eval_dset = VQAClassificationDataset("minval", image_features_reader, tokenizer, dataroot="data/VQA") elif args.split == "test": eval_dset = VQAClassificationDataset("test", image_features_reader, tokenizer, dataroot="data/VQA") elif args.split == "val": eval_dset = VQAClassificationDataset("val", image_features_reader, tokenizer, dataroot="data/VQA") elif args.split == "test-dev": eval_dset = VQAClassificationDataset("test-dev", image_features_reader, tokenizer, dataroot="data/VQA") num_labels = eval_dset.num_ans_candidates if args.from_pretrained: model = MultiModalBertForVQA.from_pretrained(args.pretrained_weight, config, num_labels=num_labels) else: model = MultiModalBertForVQA.from_pretrained(args.bert_model, config, num_labels=num_labels) if args.fp16: model.half() if args.local_rank != -1: try: from apex.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." ) model = DDP(model) elif n_gpu > 1: model = DataParallel(model, use_chuncks=args.use_chunk) model.cuda() logger.info("***** Running evaluation *****") logger.info(" Num examples = %d", len(eval_dset)) logger.info(" Batch size = %d", args.train_batch_size) eval_dataloader = DataLoader( eval_dset, shuffle=False, batch_size=args.train_batch_size, num_workers=args.num_workers, pin_memory=True, ) startIterID = 0 global_step = 0 masked_loss_v_tmp = 0 masked_loss_t_tmp = 0 next_sentence_loss_tmp = 0 loss_tmp = 0 start_t = timer() model.train(False) eval_score, bound = evaluate(args, model, eval_dataloader) logger.info("\teval score: %.2f (%.2f)" % (100 * eval_score, 100 * bound))
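# NOTE (added example): `evaluate` above returns `(eval_score, bound)`. In the standard
# VQA v2 protocol the score of a predicted answer is min(#annotators_who_gave_it / 3, 1),
# and `bound` is the best score any single answer could reach. A hedged sketch assuming
# soft target scores of shape (N, num_answers), as produced by the usual VQA pipelines;
# not necessarily this repository's exact implementation:

import torch


def vqa_batch_score(logits, target_scores):
    # pick the arg-max answer and look up its soft ground-truth score
    one_hot = torch.zeros_like(target_scores)
    one_hot.scatter_(1, logits.argmax(dim=1, keepdim=True), 1)
    score = (one_hot * target_scores).sum()
    bound = target_scores.max(dim=1)[0].sum()  # upper bound any single answer could reach
    return score, bound  # divide each by the number of questions for an accuracy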
def main_worker(gpu, ngpus_per_node, args): dist.init_process_group(backend='nccl') torch.cuda.set_device(gpu) #################################################data PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') WEIGHTS_PATH = Path('./weights/train_1024/') WEIGHTS_PATH.mkdir(exist_ok=True) batch_size = 60 normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) ####################################### train_sampler = torch.utils.data.distributed.DistributedSampler(train_dset) train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=train_sampler) ####################################### val_sampler = torch.utils.data.distributed.DistributedSampler(val_dset) val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, num_workers=2,pin_memory=True,sampler=val_sampler) ####################################### test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=test_sampler) ####################################### print("Train: %d" %len(train_loader.dataset.imgs)) print("Val: %d" %len(val_loader.dataset.imgs)) print("Test: %d" %len(test_loader.dataset.imgs)) print("Classes: %d" % len(train_loader.dataset.classes)) inputs, targets = next(iter(train_loader)) print("Inputs: ", inputs.size()) print("Targets: ", targets.size()) # utils.imgs.view_image(inputs[0]) # utils.imgs.view_annotated(targets[0]) EE = 4 device = 'cuda' EE_size = 256 LR = 1e-3 model = UNet4(n_channels=3, n_classes=4) # gpu = args.local_rank torch.cuda.set_device(gpu) model.cuda(gpu) # model = model.to(device) # model = torch.nn.DataParallel(model).cuda() ######################################## optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) model, optimizer = amp.initialize(model,optimizer) model = DistributedDataParallel(model) ################################################### # print('EE, model', EE, model) pred_dir = './train_PG_pred/' FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' ############################# EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# epoch_num = 10000 best_loss = 1. best_dice = 0. 
LR_DECAY = 0.95 DECAY_EVERY_N_EPOCHS = 10 criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda(gpu) cudnn.benchmark = True for epoch in range(1, epoch_num): start_time = datetime.datetime.now() print('start_time',start_time) model = model.cuda() ################################################## ### Train ### trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) ## Test ### val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) ### Checkpoint ### DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) print('-----------test_dice',DICE1) if best_dice < DICE1: # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) best_dice = DICE1 ### Adjust Lr ### adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) end_time = datetime.datetime.now() print('end_time', end_time) print('time', (end_time - start_time).seconds)
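# NOTE (added example): `adjust_learning_rate(LR, LR_DECAY, optimizer, epoch,
# DECAY_EVERY_N_EPOCHS)` is imported from elsewhere; its call signature matches the
# common step-decay helper, sketched here under that assumption:

def adjust_learning_rate(lr, decay, optimizer, cur_epoch, every_n_epochs):
    """Set the lr to the base lr decayed by `decay` every `every_n_epochs` epochs."""
    new_lr = lr * (decay ** (cur_epoch // every_n_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr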
class Trans2Net(BaseModel):
    def __init__(self, cfg, writer=None, batch_norm=nn.BatchNorm2d):
        super().__init__(cfg)
        self.phase = cfg.PHASE
        self.trans = not cfg.NO_TRANS
        self.content_model = None
        self.writer = writer
        self.batch_size_train = cfg.BATCH_SIZE_TRAIN
        self.batch_size_val = cfg.BATCH_SIZE_VAL
        self.batch_norm = batch_norm
        self._define_networks()
        self.params_list = []
        # self.set_criterion(cfg)

    def _define_networks(self):
        networks.batch_norm = self.batch_norm
        self.net = networks.define_netowrks(self.cfg, device=self.device)
        self.model_names = ['net']
        if 'GAN' in self.cfg.LOSS_TYPES:
            self.discriminator = networks.GANDiscriminator_Image(self.cfg, device=self.device)
            self.model_names.append('discriminator')

        # if 'PSP' in cfg.MODEL:
        #     self.modules_ft = [self.net.layer0, self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
        #     self.modules_sc = [self.net.ppm, self.net.cls, self.net.aux, self.net.score_aux1, self.net.score_aux2]
        #
        #     if self.trans:
        #         self.modules_ft.extend(
        #             [self.net.up0, self.net.up1, self.net.up2, self.net.up3,
        #              self.net.up4, self.net.up5, self.net.up_seg])
        #
        #     for module in self.modules_sc:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR * 5))
        #     for module in self.modules_ft:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR))
        # else:
        #     # self.modules_ft = [self.net.layer0, self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
        #     self.modules_sc = [self.net.score_head, self.net.score_aux1, self.net.score_aux2]
        #     if self.trans:
        #         self.modules_sc.extend(
        #             [self.net.up1, self.net.up2, self.net.up3, self.net.up4, self.net.up5, self.net.up_image])
        #     for module in self.modules_sc:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR * 5))
        #     for module in self.modules_ft:
        #         self.params_list.append(dict(params=module.parameters(), lr=cfg.LR))

        if self.cfg.USE_FAKE_DATA or self.cfg.USE_COMPL_DATA:
            print('Use fake data: sample model is {0}'.format(self.cfg.SAMPLE_MODEL_PATH))
            print('fake ratio:', self.cfg.FAKE_DATA_RATE)
            cfg_sample = copy.deepcopy(self.cfg)
            cfg_sample.USE_FAKE_DATA = False
            cfg_sample.USE_COMPL_DATA = False
            cfg_sample.NO_TRANS = False
            cfg_sample.MODEL = 'trecg_compl'
            model = networks.define_netowrks(cfg_sample, device=self.device)
            checkpoint_path = os.path.join(self.cfg.CHECKPOINTS_DIR, self.cfg.SAMPLE_MODEL_PATH)
            self._load_checkpoint(model, checkpoint_path, key='net', keep_fc=False)
            # for mit 67
            # self.net = copy.deepcopy(model.compl_net)
            model.eval()
            if self.cfg.USE_COMPL_DATA:
                self.net.set_sample_model(model)
            else:
                self.sample_model = nn.DataParallel(model).to(self.device)

        networks.print_network(self.net)
        # print('Use fake data: sample model is {0}'.format(cfg.SAMPLE_MODEL_PATH))
        # print('fake ratio:', cfg.FAKE_DATA_RATE)
        # sample_model_path = cfg.SAMPLE_MODEL_PATH
        # cfg_sample = copy.deepcopy(cfg)
        # cfg_sample.USE_FAKE_DATA = False
        # model = networks.define_netowrks(cfg_sample, device=self.device)
        # self.load_checkpoint(net=model, checkpoint_path=sample_model_path)
        # model.eval()
        # self.sample_model = nn.DataParallel(model).to(self.device)

    def set_device(self):
        if not self.cfg.MULTIPROCESSING_DISTRIBUTED:
            self.net = nn.DataParallel(self.net).to(self.device)
            if 'GAN' in self.cfg.LOSS_TYPES:
                self.discriminator = nn.DataParallel(self.discriminator).to(self.device)

    def _optimize(self, iter):
        self._forward()
        if 'GAN' in self.cfg.LOSS_TYPES:
            self.set_requires_grad(self.net, False)
            self.set_requires_grad(self.discriminator, True)
            fake_d =
self.result['gen_img'] real_d = self.target_modal # fake_d = self.result_c['feat_gen'] # real_d = self.result_c['feat_target'] if self.cfg.MULTIPROCESSING_DISTRIBUTED: loss_d_fake = self.discriminator(fake_d.detach(), False) loss_d_true = self.discriminator(real_d.detach(), True) else: loss_d_fake = self.discriminator(fake_d.detach(), False).mean() loss_d_true = self.discriminator(real_d.detach(), True).mean() loss_d = (loss_d_fake + loss_d_true) * 0.5 self.loss_meters['TRAIN_GAN_D_LOSS'].update( loss_d.item(), self.batch_size) self.optimizer_d.zero_grad() if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED: with apex.amp.scale_loss(loss_d, self.optimizer_d) as scaled_loss: scaled_loss.backward() else: loss_d.backward() self.optimizer_d.step() loss_g = self._construct_loss(iter) if 'GAN' in self.cfg.LOSS_TYPES and self.discriminator is not None: self.set_requires_grad(self.discriminator, False) self.set_requires_grad(self.net, True) self.optimizer.zero_grad() if self.cfg.USE_APEX and self.cfg.MULTIPROCESSING_DISTRIBUTED: with apex.amp.scale_loss(loss_g, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss_g.backward() self.optimizer.step() def set_criterion(self, cfg): if 'CLS' in self.cfg.LOSS_TYPES or self.cfg.EVALUATE: criterion_cls = util.CrossEntropyLoss( weight=cfg.CLASS_WEIGHTS_TRAIN, device=self.device, ignore_index=cfg.IGNORE_LABEL) self.net.set_cls_criterion(criterion_cls) if 'SEMANTIC' in self.cfg.LOSS_TYPES: criterion_content = torch.nn.L1Loss() content_model = networks.Content_Model(cfg, criterion_content).to( self.device) self.net.set_content_model(content_model) if 'PIX2PIX' in self.cfg.LOSS_TYPES: criterion_pix2pix = torch.nn.L1Loss() self.net.set_pix2pix_criterion(criterion_pix2pix) def set_input(self, data): self._source = data['image'] self.source_modal = self._source.to(self.device) self.batch_size = self._source.size()[0] if 'label' in data.keys(): self._label = data['label'] self.label = torch.LongTensor(self._label).to(self.device) else: self.label = None if self.cfg.TARGET_MODAL: if self.cfg.MULTI_SCALE: self.target_modal = data[self.cfg.TARGET_MODAL][-1].to( self.device) else: self.target_modal = data[self.cfg.TARGET_MODAL].to(self.device) else: self.target_modal = None # if self.trans or self.cfg.RESUME: # if not self.cfg.MULTI_SCALE: # self.target_modal = self.target_modal # else: # if self.cfg.WHICH_DIRECTION == 'BtoA': # self.source_modal, self.target_modal = self.target_modal, self.source_modal def train_parameters(self, cfg): assert self.cfg.LOSS_TYPES self.set_optimizer(cfg) self.set_log_data(cfg) self.set_schedulers(cfg) self.set_device() # self.net = nn.DataParallel(self.net).to(self.device) train_iters = 0 best_result = 0 if self.cfg.EVALUATE and self.cfg.SLIDE_WINDOWS: self.prediction_matrix = torch.zeros( self.batch_size_val, self.cfg.NUM_CLASSES, self.cfg.BASE_SIZE[0], self.cfg.BASE_SIZE[1]).to(self.device) self.count_crop_matrix = torch.zeros( self.batch_size_val, 1, self.cfg.BASE_SIZE[0], self.cfg.BASE_SIZE[1]).to(self.device) if cfg.INFERENCE: self.phase = 'test' start_time = time.time() print('Inferencing model...') self.evaluate() self.print_evaluate_results() save_dir = './images/' # np.savetxt(save_dir+'/target.txt',self.target_index_all) # np.savetxt(save_dir+'/pred.txt',self.pred_index_all) np.savetxt(save_dir + '/class_baseline.txt', self.accuracy_class) # self.target_index_all=np.loadtxt(save_dir+'/target.txt') # self.pred_index_all=np.loadtxt(save_dir+'/pred.txt') # from sklearn.metrics import confusion_matrix # 
cm=confusion_matrix(self.target_index_all,self.pred_index_all) util.plot_confusion_matrix(self.target_index_all, self.pred_index_all, self.val_loader.dataset.classes) print('Evaluation Time: {0} sec'.format(time.time() - start_time)) # self.write_loss(phase=self.phase) return if cfg.MULTIPROCESSING_DISTRIBUTED: total_epoch = int(cfg.NITER_TOTAL / math.ceil( (self.train_image_num / (cfg.BATCH_SIZE_TRAIN * len(cfg.GPU_IDS))))) else: total_epoch = int(cfg.NITER_TOTAL / math.ceil( (self.train_image_num / cfg.BATCH_SIZE_TRAIN))) print('total epoch:{0}, total iters:{1}'.format( total_epoch, cfg.NITER_TOTAL)) for epoch in range(cfg.START_EPOCH, total_epoch + 1): if train_iters > cfg.NITER_TOTAL: break if cfg.MULTIPROCESSING_DISTRIBUTED: cfg.train_sampler.set_epoch(epoch) self.print_lr() # current_lr = util.poly_learning_rate(cfg.LR, train_iters, cfg.NITER_TOTAL, power=0.8) # if cfg.LR_POLICY != 'plateau': # self.update_learning_rate(step=train_iters) # else: # self.update_learning_rate(val=self.loss_meters['VAL_CLS_LOSS'].avg) self.fake_image_num = 0 start_time = time.time() self.phase = 'train' self.net.train() # reset Averagemeters on each epoch for key in self.loss_meters: self.loss_meters[key].reset() iters = 0 print('gpu_ids:', cfg.GPU_IDS) print('# Training images num = {0}'.format(self.train_image_num)) # batch = tqdm(self.train_loader) # for data in batch: for data in self.train_loader: self.set_input(data) train_iters += 1 iters += 1 self._optimize(train_iters) self.update_learning_rate(step=train_iters) # self.val_iou = self.validate(train_iters) # self.write_loss(phase=self.phase, global_step=train_iters) print('log_path:', cfg.LOG_PATH) print('iters in one epoch:', iters) self.write_loss(phase=self.phase, global_step=train_iters) print('Epoch: {epoch}/{total}'.format(epoch=epoch, total=total_epoch)) util.print_current_errors( util.get_current_errors(self.loss_meters, current=False), epoch) print('Training Time: {0} sec'.format(time.time() - start_time)) # if cfg.EVALUATE: if (epoch % self.cfg.EVALUATE_FREQ == 0 or epoch > total_epoch - 10 or epoch == total_epoch) and cfg.EVALUATE: print('# Cls val images num = {0}'.format(self.val_image_num)) self.evaluate() self.print_evaluate_results() self.write_loss(phase=self.phase, global_step=train_iters) # save best model if cfg.SAVE_BEST and epoch > total_epoch - 10: # save model for key in self.loss_meters: if 'MEAN' in key and self.loss_meters[key].val > 0: if self.loss_meters[key].val > best_result: best_result = self.loss_meters[key].val model_filename = 'best_{0}.pth'.format( self.cfg.LOG_NAME) print('best epoch / iters are {0}/{1}'.format( epoch, iters)) self.save_checkpoint(model_filename) print('best {0} is {1}, epoch is {2}, iters {3}'. 
    def evaluate(self):
        if not self.cfg.SLIDE_WINDOWS:
            self.validate()
        else:
            self.validate_slide_window()

    def save_best(self, best_result, epoch=None, iters=None):
        if self.cfg.TASK_TYPE == 'segmentation':
            result = self.loss_meters['VAL_CLS_MEAN_IOU'].val
        elif self.cfg.TASK_TYPE == 'recognition':
            result = self.loss_meters['VAL_CLS_MEAN_ACC'].val
        is_best = result > best_result
        best_result = max(result, best_result)
        if is_best:
            model_filename = 'best_{0}.pth'.format(self.cfg.LOG_NAME)
            print('best epoch / iters are {0}/{1}'.format(epoch, iters))
            self.save_checkpoint(model_filename)
            print('best miou is {0}, epoch is {1}, iters {2}'.format(
                best_result, epoch, iters))

    def print_evaluate_results(self):
        if self.cfg.TASK_TYPE == 'segmentation':
            print('MIOU: {miou}, mAcc: {macc}, acc: {acc}'.format(
                miou=self.loss_meters['VAL_CLS_MEAN_IOU'].val * 100,
                macc=self.loss_meters['VAL_CLS_MEAN_ACC'].val * 100,
                acc=self.loss_meters['VAL_CLS_ACC'].val * 100))
        elif self.cfg.TASK_TYPE == 'recognition':
            print('Mean Acc Top1 <{mean_acc:.3f}> '.format(
                mean_acc=self.loss_meters['VAL_CLS_MEAN_ACC'].val * 100))
        elif self.cfg.TASK_TYPE == 'infomax':
            print('Mean Acc Top1 MLP: <{mean_acc:.3f}> '.format(
                mean_acc=self.loss_meters['VAL_CLS_MEAN_ACC_MLP'].val * 100))

    def _forward(self, cal_loss=True):
        if self.cfg.USE_FAKE_DATA:
            # replace a random subset of the batch with generated images
            with torch.no_grad():
                result_sample = self.sample_model(source=self.source_modal,
                                                  target=None, label=None,
                                                  phase=self.phase, cal_loss=False)
            fake_imgs = result_sample['gen_img']
            input_num = len(fake_imgs)
            indexes = [i for i in range(input_num)]
            random_index = random.sample(
                indexes, int(len(fake_imgs) * self.cfg.FAKE_DATA_RATE))
            for i in random_index:
                self.source_modal[i, :] = fake_imgs.data[i, :]

        self.result = self.net(source=self.source_modal, target=self.target_modal,
                               label=self.label, phase=self.phase, cal_loss=cal_loss)

    def _construct_loss(self, iter):
        loss_total = torch.zeros(1).to(self.device)

        if 'CLS' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                cls_loss = self.result['loss_cls'] * self.cfg.ALPHA_CLS
                loss_total += cls_loss
                dist.all_reduce(cls_loss)
                if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                    cls_loss_compl = self.result['loss_cls_compl'] * self.cfg.ALPHA_CLS
                    loss_total += cls_loss_compl
                    # cls_loss_fuse = self.result['loss_cls_fuse'] * self.cfg.ALPHA_CLS
                    # loss_total += cls_loss_fuse
                    dist.all_reduce(cls_loss_compl)
            else:
                cls_loss = self.result['loss_cls'].mean() * self.cfg.ALPHA_CLS
                loss_total += cls_loss
                if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                    cls_loss_compl = self.result['loss_cls_compl'].mean() * self.cfg.ALPHA_CLS
                    loss_total += cls_loss_compl
                    # cls_loss_fuse = self.result['loss_cls_fuse'].mean() * self.cfg.ALPHA_CLS
                    # loss_total += cls_loss_fuse

            self.loss_meters['TRAIN_CLS_LOSS'].update(cls_loss.item(), self.batch_size)
            if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                self.loss_meters['TRAIN_CLS_LOSS_COMPL'].update(
                    cls_loss_compl.item(), self.batch_size)
                # self.loss_meters['TRAIN_CLS_LOSS_FUSE'].update(cls_loss_fuse.item(), self.batch_size)
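        # NOTE on the distributed branch above: dist.all_reduce() defaults to
        # SUM, so the meters log the loss summed over ranks (the in-place
        # reduce runs after the local value was already added to loss_total).
        # A sketch of logging the cross-rank mean instead, without touching the
        # training graph (assumes torch.distributed is initialized):
        #     cls_loss_log = cls_loss.detach().clone()
        #     dist.all_reduce(cls_loss_log)
        #     cls_loss_log /= dist.get_world_size()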
        # content supervised
        if 'SEMANTIC' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTI_MODAL:
                self.gen = [self.result['gen_img_1'], self.result['gen_img_2']]
            else:
                self.gen = self.result['gen_img']

            decay_coef = 1
            # decay_coef = (iter / self.cfg.NITER_TOTAL)  # small to big
            # decay_coef = max(0, (self.cfg.NITER_TOTAL - iter) / self.cfg.NITER_TOTAL)  # big to small
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                content_loss = self.result['loss_content'] * self.cfg.ALPHA_CONTENT * decay_coef
                loss_total += content_loss
                dist.all_reduce(content_loss)
                # content_loss = content_loss.detach() / self.batch_size
            else:
                content_loss = self.result['loss_content'].mean() * self.cfg.ALPHA_CONTENT * decay_coef
                loss_total += content_loss

            self.loss_meters['TRAIN_SEMANTIC_LOSS'].update(content_loss.item(), self.batch_size)

        if 'PIX2PIX' in self.cfg.LOSS_TYPES:
            if self.cfg.MULTI_MODAL:
                self.gen = [self.result['gen_img_1'], self.result['gen_img_2']]
            else:
                self.gen = self.result['gen_img']

            decay_coef = 1
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                pix2pix_loss = self.result['loss_pix2pix'] * self.cfg.ALPHA_PIX2PIX * decay_coef
                loss_total += pix2pix_loss
            else:
                pix2pix_loss = self.result['loss_pix2pix'].mean() * self.cfg.ALPHA_PIX2PIX * decay_coef
                loss_total += pix2pix_loss

            # .item() so the meter stores a float, not a graph-holding tensor
            self.loss_meters['TRAIN_PIX2PIX_LOSS'].update(pix2pix_loss.item(), self.batch_size)

        if 'GAN' in self.cfg.LOSS_TYPES:
            real_g = self.result['gen_img']
            # real_g = torch.cat((self.result['gen_img'], self.source_modal), 1)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                loss_gan_g = self.discriminator(real_g, True) * self.cfg.ALPHA_GAN
            else:
                loss_gan_g = self.discriminator(real_g, True).mean() * self.cfg.ALPHA_GAN

            self.loss_meters['TRAIN_GAN_G_LOSS'].update(loss_gan_g.item(), self.batch_size)
            loss_total += loss_gan_g

        return loss_total

    def set_log_data(self, cfg):
        self.loss_meters = defaultdict()
        self.log_keys = [
            'TRAIN_GAN_G_LOSS',
            'TRAIN_GAN_D_LOSS',
            'TRAIN_SEMANTIC_LOSS',  # semantic
            'TRAIN_PIX2PIX_LOSS',
            'TRAIN_CLS_ACC',
            'VAL_CLS_ACC',  # classification
            'TRAIN_CLS_LOSS',
            'TRAIN_CLS_MEAN_IOU',
            'VAL_CLS_LOSS',
            'VAL_CLS_MEAN_IOU',
            'VAL_CLS_MEAN_ACC',
            'INTERSECTION',
            'UNION',
            'LABEL',
            'TRAIN_CLS_LOSS_COMPL',
            'TRAIN_CLS_LOSS_FUSE'
        ]
        for item in self.log_keys:
            self.loss_meters[item] = AverageMeter()

    def set_optimizer(self, cfg):
        self.optimizers = []
        # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=cfg.LR, betas=(0.5, 0.999))
        if self.params_list:
            self.optimizer = torch.optim.Adam(self.params_list,
                                              lr=cfg.LR, betas=(0.5, 0.999))
        else:
            self.optimizer = torch.optim.Adam(self.net.parameters(),
                                              lr=cfg.LR, betas=(0.5, 0.999))
        # self.optimizer = torch.optim.SGD(self.net.parameters(), lr=cfg.LR,
        #                                  momentum=cfg.MOMENTUM, weight_decay=cfg.WEIGHT_DECAY)

        if cfg.MULTIPROCESSING_DISTRIBUTED:
            if cfg.USE_APEX:
                # amp.initialize must wrap model/optimizer before the DDP wrapper
                self.net, self.optimizer = apex.amp.initialize(
                    self.net.cuda(), self.optimizer, opt_level=cfg.opt_level)
                self.net = DDP(self.net)
            else:
                self.net = torch.nn.parallel.DistributedDataParallel(
                    self.net.cuda(), device_ids=[cfg.gpu])
        self.optimizers.append(self.optimizer)

        if 'GAN' in self.cfg.LOSS_TYPES:
            self.optimizer_d = torch.optim.SGD(self.discriminator.parameters(),
                                               lr=cfg.LR, momentum=0.9,
                                               weight_decay=0.0005)
            if cfg.MULTIPROCESSING_DISTRIBUTED:
                if cfg.USE_APEX:
                    self.discriminator, self.optimizer_d = apex.amp.initialize(
                        self.discriminator.cuda(), self.optimizer_d,
                        opt_level=cfg.opt_level)
                    self.discriminator = DDP(self.discriminator)
                else:
                    self.discriminator = torch.nn.parallel.DistributedDataParallel(
                        self.discriminator.cuda(), device_ids=[cfg.gpu])
            self.optimizers.append(self.optimizer_d)
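    # Aside: apex.amp is deprecated upstream; a minimal sketch of the same
    # mixed-precision step with native torch.cuda.amp (illustrative only --
    # this trainer uses apex, and `scaler` does not exist in this class):
    #     scaler = torch.cuda.amp.GradScaler()
    #     with torch.cuda.amp.autocast():
    #         loss = self._construct_loss(iter)
    #     scaler.scale(loss).backward()
    #     scaler.step(self.optimizer)
    #     scaler.update()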
    def validate_slide_window(self):
        self.net.eval()
        self.phase = 'test'

        intersection_meter = self.loss_meters['INTERSECTION']
        union_meter = self.loss_meters['UNION']
        target_meter = self.loss_meters['LABEL']

        print('testing with sliding windows...')
        num_images = 0
        # batch = tqdm(self.val_loader)
        # for data in batch:
        for data in self.val_loader:
            self.set_input(data)
            num_images += self.batch_size
            pred = util.slide_cal(
                model=self.net, image=self.source_modal,
                crop_size=self.cfg.FINE_SIZE,
                prediction_matrix=self.prediction_matrix[0:self.batch_size, :, :, :],
                count_crop_matrix=self.count_crop_matrix[0:self.batch_size, :, :, :])
            self.pred = pred.data.max(1)[1]

            intersection, union, label = util.intersectionAndUnionGPU(
                self.pred, self.label, self.cfg.NUM_CLASSES)
            if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(label)
            intersection = intersection.cpu().numpy()
            union = union.cpu().numpy()
            label = label.cpu().numpy()
            intersection_meter.update(intersection, self.batch_size)
            union_meter.update(union, self.batch_size)
            target_meter.update(label, self.batch_size)

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        mIoU = np.mean(iou_class)
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

        self.loss_meters['VAL_CLS_ACC'].update(allAcc)
        self.loss_meters['VAL_CLS_MEAN_ACC'].update(mAcc)
        self.loss_meters['VAL_CLS_MEAN_IOU'].update(mIoU)

    def validate(self):
        self.phase = 'test'
        # switch to evaluate mode
        self.net.eval()

        intersection_meter = self.loss_meters['INTERSECTION']
        union_meter = self.loss_meters['UNION']
        target_meter = self.loss_meters['LABEL']

        if self.cfg.USE_FAKE_DATA or self.cfg.INFERENCE:
            self.pred_index_all = []
            self.target_index_all = []

        with torch.no_grad():
            # batch_index = int(self.val_image_num / cfg.BATCH_SIZE)
            # random_id = random.randint(0, batch_index)
            # batch = tqdm(self.val_loader)
            # for data in batch:
            for i, data in enumerate(self.val_loader):
                self.set_input(data)
                self._forward(cal_loss=False)
                if self.cfg.INFERENCE:
                    self._process_fc()
                self.pred = self.result['cls'].data.max(1)[1]

                intersection, union, label = util.intersectionAndUnionGPU(
                    self.pred, self.label, self.cfg.NUM_CLASSES)
                if self.cfg.MULTIPROCESSING_DISTRIBUTED:
                    dist.all_reduce(intersection)
                    dist.all_reduce(union)
                    dist.all_reduce(label)
                intersection = intersection.cpu().numpy()
                union = union.cpu().numpy()
                label = label.cpu().numpy()
                intersection_meter.update(intersection, self.batch_size)
                union_meter.update(union, self.batch_size)
                target_meter.update(label, self.batch_size)

        # Mean ACC
        # self._cal_mean_acc(self.cfg, self.val_loader)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        self.accuracy_class = accuracy_class
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
        self.loss_meters['VAL_CLS_ACC'].update(allAcc)
        self.loss_meters['VAL_CLS_MEAN_ACC'].update(mAcc)

        if self.cfg.TASK_TYPE == 'segmentation':
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            mIoU = np.mean(iou_class)
            self.loss_meters['VAL_CLS_MEAN_IOU'].update(mIoU)

    def _process_fc(self):
        # dist.all_reduce(self.result['cls'])
        _, index = self.result['cls'].data.topk(1, 1, largest=True)
        self.pred_index_all.extend(list(index.cpu().numpy()))
        self.target_index_all.extend(list(self._label.numpy()))

    def _cal_mean_acc(self, cfg, data_loader):
        mean_acc = util.mean_acc(np.array(self.target_index_all),
                                 np.array(self.pred_index_all),
                                 cfg.NUM_CLASSES, data_loader.dataset.classes)
        return mean_acc
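    # write_loss() below assumes self.writer is a TensorBoard-style writer
    # created elsewhere in the trainer. A minimal sketch of that assumption
    # (attribute name and LOG_PATH come from this file; the import is the
    # standard torch one):
    #     from torch.utils.tensorboard import SummaryWriter
    #     self.writer = SummaryWriter(log_dir=cfg.LOG_PATH)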
    def write_loss(self, phase, global_step=1):
        loss_types = self.cfg.LOSS_TYPES
        task = self.cfg.TASK_TYPE

        if self.phase == 'train':
            label_show = self.label.data.cpu().numpy()
        else:
            label_show = np.uint8(self.label.data.cpu())

        source_modal_show = self.source_modal
        target_modal_show = self.target_modal

        if phase == 'train':
            self.writer.add_image(
                task + '/Train_image',
                torchvision.utils.make_grid(
                    source_modal_show[:6].clone().cpu().data, 3, normalize=True),
                global_step=global_step)
            self.writer.add_scalar(task + '/LR',
                                   self.optimizer.param_groups[0]['lr'],
                                   global_step=global_step)

            if 'CLS' in loss_types:
                self.writer.add_scalar(task + '/TRAIN_CLS_LOSS',
                                       self.loss_meters['TRAIN_CLS_LOSS'].avg,
                                       global_step=global_step)
                if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                    self.writer.add_scalar(
                        task + '/TRAIN_CLS_LOSS_COMPL',
                        self.loss_meters['TRAIN_CLS_LOSS_COMPL'].avg,
                        global_step=global_step)
                    self.writer.add_image(
                        task + '/Compl_image',
                        torchvision.utils.make_grid(
                            self.result['compl_source'][:6].clone().cpu().data,
                            3, normalize=True),
                        global_step=global_step)
                # self.writer.add_scalar('TRAIN_CLS_ACC',
                #                        self.loss_meters['TRAIN_CLS_ACC'].avg * 100.0,
                #                        global_step=global_step)
                # self.writer.add_scalar('TRAIN_CLS_MEAN_IOU',
                #                        float(self.train_iou.mean()) * 100.0,
                #                        global_step=global_step)

            if self.trans and not self.cfg.MULTI_MODAL:
                if 'SEMANTIC' in self.cfg.LOSS_TYPES:
                    self.writer.add_scalar(
                        task + '/TRAIN_SEMANTIC_LOSS',
                        self.loss_meters['TRAIN_SEMANTIC_LOSS'].avg,
                        global_step=global_step)
                if 'PIX2PIX' in self.cfg.LOSS_TYPES:
                    self.writer.add_scalar(
                        task + '/TRAIN_PIX2PIX_LOSS',
                        self.loss_meters['TRAIN_PIX2PIX_LOSS'].avg,
                        global_step=global_step)

                self.writer.add_image(
                    task + '/Train_gen',
                    torchvision.utils.make_grid(
                        self.gen.data[:6].clone().cpu().data, 3, normalize=True),
                    global_step=global_step)
                self.writer.add_image(
                    task + '/Train_image',
                    torchvision.utils.make_grid(
                        source_modal_show[:6].clone().cpu().data, 3, normalize=True),
                    global_step=global_step)
                # if isinstance(self.target_modal, list):
                #     for i, (gen, target) in enumerate(zip(self.gen, self.target_modal)):
                #         self.writer.add_image('Seg/2_Train_Gen_' + str(self.cfg.FINE_SIZE / pow(2, i)),
                #                               torchvision.utils.make_grid(gen[:6].clone().cpu().data,
                #                                                           3, normalize=True),
                #                               global_step=global_step)
                #         self.writer.add_image('Seg/3_Train_Target_' + str(self.cfg.FINE_SIZE / pow(2, i)),
                #                               torchvision.utils.make_grid(target[:6].clone().cpu().data,
                #                                                           3, normalize=True),
                #                               global_step=global_step)
                # else:
                self.writer.add_image(
                    task + '/Train_target',
                    torchvision.utils.make_grid(
                        target_modal_show[:6].clone().cpu().data, 3, normalize=True),
                    global_step=global_step)

            if 'CLS' in loss_types and self.cfg.TASK_TYPE == 'segmentation':
                train_pred = self.result['cls'].data.max(1)[1].cpu().numpy()
                self.writer.add_image(
                    task + '/Train_predicted',
                    torchvision.utils.make_grid(
                        torch.from_numpy(util.color_label(
                            train_pred[:6], ignore=self.cfg.IGNORE_LABEL,
                            dataset=self.cfg.DATASET)),
                        3, normalize=True, range=(0, 255)),
                    global_step=global_step)
                self.writer.add_image(
                    task + '/Train_label',
                    torchvision.utils.make_grid(
                        torch.from_numpy(util.color_label(
                            label_show[:6], ignore=self.cfg.IGNORE_LABEL,
                            dataset=self.cfg.DATASET)),
                        3, normalize=True, range=(0, 255)),
                    global_step=global_step)

        elif phase == 'test':
            self.writer.add_image(
                task + '/Val_image',
                torchvision.utils.make_grid(
                    source_modal_show[:6].clone().cpu().data, 3, normalize=True),
                global_step=global_step)
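            # In the make_grid() calls here, the second positional argument is
            # nrow (images per row), so each grid shows up to 6 samples laid
            # out in rows of 3.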
            # self.writer.add_image('Seg/Val_image',
            #                       torchvision.utils.make_grid(source_modal_show[:6].clone().cpu().data,
            #                                                   3, normalize=True),
            #                       global_step=global_step)
            # self.writer.add_image('Seg/Val_predicted',
            #                       torchvision.utils.make_grid(
            #                           torch.from_numpy(util.color_label(self.pred[:6],
            #                                                             ignore=self.cfg.IGNORE_LABEL,
            #                                                             dataset=self.cfg.DATASET)),
            #                           3, normalize=True, range=(0, 255)),
            #                       global_step=global_step)
            # self.writer.add_image('Seg/Val_label',
            #                       torchvision.utils.make_grid(torch.from_numpy(
            #                           util.color_label(label_show[:6],
            #                                            ignore=self.cfg.IGNORE_LABEL,
            #                                            dataset=self.cfg.DATASET)),
            #                           3, normalize=True, range=(0, 255)),
            #                       global_step=global_step)

            if 'compl' in self.cfg.MODEL or self.cfg.USE_COMPL_DATA:
                self.writer.add_image(
                    task + '/Compl_image',
                    torchvision.utils.make_grid(
                        self.result['compl_source'][:6].clone().cpu().data,
                        3, normalize=True),
                    global_step=global_step)

            self.writer.add_scalar(task + '/VAL_CLS_ACC',
                                   self.loss_meters['VAL_CLS_ACC'].val * 100.0,
                                   global_step=global_step)
            self.writer.add_scalar(task + '/VAL_CLS_MEAN_ACC',
                                   self.loss_meters['VAL_CLS_MEAN_ACC'].val * 100.0,
                                   global_step=global_step)
            if task == 'segmentation':
                self.writer.add_scalar(
                    task + '/VAL_CLS_MEAN_IOU',
                    self.loss_meters['VAL_CLS_MEAN_IOU'].val * 100.0,
                    global_step=global_step)
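    # Metric definitions used by validate()/validate_slide_window(), for
    # reference: per class c, IoU_c = intersection_c / union_c and
    # acc_c = intersection_c / label_c; mIoU and mAcc average these over
    # classes, and allAcc = sum(intersection) / sum(label). E.g. (illustrative
    # numbers) intersection.sum = [50, 30] and union.sum = [100, 60] give
    # iou_class = [0.5, 0.5], so mIoU = 0.5; the 1e-10 avoids division by
    # zero for classes absent from the split.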