def train_net(args, config): # setup logger logger, final_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET, split='train') model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX) if args.log_dir is None: args.log_dir = os.path.join(final_output_path, 'tensorboard_logs') pprint.pprint(args) logger.info('training args:{}\n'.format(args)) pprint.pprint(config) logger.info('training config:{}\n'.format(pprint.pformat(config))) # manually set random seed if config.RNG_SEED > -1: np.random.seed(config.RNG_SEED) torch.random.manual_seed(config.RNG_SEED) torch.cuda.manual_seed_all(config.RNG_SEED) # cudnn torch.backends.cudnn.benchmark = False if args.cudnn_off: torch.backends.cudnn.enabled = False if args.dist: model = eval(config.MODULE)(config) local_rank = int(os.environ.get('LOCAL_RANK') or 0) config.GPUS = str(local_rank) torch.cuda.set_device(local_rank) master_address = os.environ['MASTER_ADDR'] master_port = int(os.environ['MASTER_PORT'] or 23456) world_size = int(os.environ['WORLD_SIZE'] or 1) rank = int(os.environ['RANK'] or 0) if args.slurm: distributed.init_process_group(backend='nccl') else: distributed.init_process_group(backend='nccl', init_method='tcp://{}:{}'.format( master_address, master_port), world_size=world_size, rank=rank, group_name='mtorch') print( f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}' ) torch.cuda.set_device(local_rank) config.GPUS = str(local_rank) model = model.cuda() if not config.TRAIN.FP16: model = DDP(model, device_ids=[local_rank], output_device=local_rank) if rank == 0: summary_parameters( model.module if isinstance( model, torch.nn.parallel.DistributedDataParallel) else model, logger) shutil.copy(args.cfg, final_output_path) shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path) writer = None if args.log_dir is not None: tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank)) if not os.path.exists(tb_log_dir): os.makedirs(tb_log_dir) writer = SummaryWriter(log_dir=tb_log_dir) train_loader, train_sampler = make_dataloader(config, mode='train', distributed=True, num_replicas=world_size, rank=rank, expose_sampler=True) val_loader = make_dataloader(config, mode='val', distributed=True, num_replicas=world_size, rank=rank) batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1: batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS base_lr = config.TRAIN.LR * batch_size optimizer_grouped_parameters = [{ 'params': [p for n, p in model.named_parameters() if _k in n], 'lr': base_lr * _lr_mult } for _k, _lr_mult in config.TRAIN.LR_MULT] optimizer_grouped_parameters.append({ 'params': [ p for n, p in model.named_parameters() if all([_k not in n for _k, _ in config.TRAIN.LR_MULT]) ] }) if config.TRAIN.OPTIMIZER == 'SGD': optimizer = optim.SGD(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'Adam': optimizer = optim.Adam(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'AdamW': optimizer = AdamW(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, betas=(0.9, 0.999), eps=1e-6, weight_decay=config.TRAIN.WD, correct_bias=True) else: raise ValueError('Not support optimizer {}!'.format( config.TRAIN.OPTIMIZER)) total_gpus = world_size else: 
#os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS model = eval(config.MODULE)(config) # import pdb; pdb.set_trace() if config.NETWORK.VLBERT.vlbert_frozen: # freeze all parameters first for p in model.parameters(): p.requires_grad = False # unfreeze the last layer(s) if config.NETWORK.VLBERT.vlbert_unfrozen_layers != 0: for p in model.vlbert.encoder.layer[ -config.NETWORK.VLBERT. vlbert_unfrozen_layers:].parameters(): p.requires_grad = True for p in model.final_mlp.parameters(): p.requires_grad = True if config.NETWORK.USE_SPATIAL_MODEL: for p in model.simple_spatial_model.parameters(): p.requires_grad = True for p in model.spa_fusion_linear.parameters(): p.requires_grad = True for p in model.spa_linear.parameters(): p.requires_grad = True if config.NETWORK.SPA_ONE_MORE_LAYER: for p in model.spa_linear_hidden.parameters(): p.requires_grad = True # If use enhanced image feature if config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE: for p in model.vlbert.obj_feat_downsample.parameters(): p.requires_grad = True for p in model.vlbert.obj_feat_batchnorm.parameters(): p.requires_grad = True for p in model.vlbert.lan_img_conv1.parameters(): p.requires_grad = True for p in model.vlbert.lan_img_conv2.parameters(): p.requires_grad = True for p in model.vlbert.lan_img_conv3.parameters(): p.requires_grad = True for p in model.vlbert.lan_img_conv4.parameters(): p.requires_grad = True if config.NETWORK.VLBERT.vlbert_frozen_embedding_LayerNorm: print('freezing embedding_LayerNorm...') for p in model.vlbert.embedding_LayerNorm.parameters(): p.requires_grad = False if config.NETWORK.VLBERT.vlbert_frozen_encoder: print('freezing encoder...') for p in model.vlbert.encoder.parameters(): p.requires_grad = False summary_parameters(model, logger) shutil.copy(args.cfg, final_output_path) shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path) num_gpus = len(config.GPUS.split(',')) assert num_gpus <= 1 or (not config.TRAIN.FP16), "Not support fp16 with torch.nn.DataParallel. " \ "Please use amp.parallel.DistributedDataParallel instead." 
total_gpus = num_gpus rank = None writer = SummaryWriter( log_dir=args.log_dir) if args.log_dir is not None else None # model if num_gpus > 1: model = torch.nn.DataParallel( model, device_ids=[int(d) for d in config.GPUS.split(',')]).cuda() else: torch.cuda.set_device(int(config.GPUS)) model.cuda() # loader train_loader = make_dataloader(config, mode=config.DATASET.TRAIN_IMAGE_SET, distributed=False) val_loader = make_dataloader(config, mode=config.DATASET.VAL_IMAGE_SET, distributed=False) test_loader = make_dataloader(config, mode=config.DATASET.TEST_IMAGE_SET, distributed=False) train_sampler = None batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance( config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1: batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS base_lr = config.TRAIN.LR * batch_size optimizer_grouped_parameters = [{ 'params': [p for n, p in model.named_parameters() if _k in n], 'lr': base_lr * _lr_mult } for _k, _lr_mult in config.TRAIN.LR_MULT] optimizer_grouped_parameters.append({ 'params': [ p for n, p in model.named_parameters() if all([_k not in n for _k, _ in config.TRAIN.LR_MULT]) ] }) if config.TRAIN.OPTIMIZER == 'SGD': optimizer = optim.SGD(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'Adam': optimizer = optim.Adam(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'AdamW': optimizer = AdamW(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, betas=(0.9, 0.999), eps=1e-6, weight_decay=config.TRAIN.WD, correct_bias=True) else: raise ValueError('Not support optimizer {}!'.format( config.TRAIN.OPTIMIZER)) # partial load pretrain state dict if config.NETWORK.PARTIAL_PRETRAIN != "": pretrain_state_dict = torch.load( config.NETWORK.PARTIAL_PRETRAIN, map_location=lambda storage, loc: storage)['state_dict'] prefix_change = [ prefix_change.split('->') for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES ] if len(prefix_change) > 0: pretrain_state_dict_parsed = {} for k, v in pretrain_state_dict.items(): no_match = True for pretrain_prefix, new_prefix in prefix_change: if k.startswith(pretrain_prefix): k = new_prefix + k[len(pretrain_prefix):] pretrain_state_dict_parsed[k] = v no_match = False break if no_match: pretrain_state_dict_parsed[k] = v pretrain_state_dict = pretrain_state_dict_parsed # import pdb; pdb.set_trace() smart_partial_load_model_state_dict(model, pretrain_state_dict) # pretrained classifier if config.NETWORK.CLASSIFIER_PRETRAINED: # false for now print( 'Initializing classifier weight from pretrained word embeddings...' 
) for k, v in model.state_dict().items(): if 'word_embeddings.weight' in k: word_embeddings = v.detach().clone() break answers_word_embed = [] for answer in config.PREDICATE_CATEGORIES: a_tokens = train_loader.dataset.tokenizer.tokenize(answer) a_ids = train_loader.dataset.tokenizer.convert_tokens_to_ids( a_tokens) a_word_embed = (torch.stack( [word_embeddings[a_id] for a_id in a_ids], dim=0)).mean(dim=0) answers_word_embed.append(a_word_embed) answers_word_embed_tensor = torch.stack(answers_word_embed, dim=0) for name, module in model.named_modules(): if name.endswith('final_mlp'): module[-1].weight.data = answers_word_embed_tensor.to( device=module[-1].weight.data.device) # metrics train_metrics_list = [ spasen_metrics.Accuracy(allreduce=args.dist, num_replicas=world_size if args.dist else 1) ] val_metrics_list = [ spasen_metrics.Accuracy(allreduce=args.dist, num_replicas=world_size if args.dist else 1) ] for output_name, display_name in config.TRAIN.LOSS_LOGGERS: train_metrics_list.append( spasen_metrics.LossLogger( output_name, display_name=display_name, allreduce=args.dist, num_replicas=world_size if args.dist else 1)) val_metrics_list.append( spasen_metrics.LossLogger( output_name, display_name=display_name, allreduce=args.dist, num_replicas=world_size if args.dist else 1)) train_metrics = CompositeEvalMetric() val_metrics = CompositeEvalMetric() for child_metric in train_metrics_list: train_metrics.add(child_metric) for child_metric in val_metrics_list: val_metrics.add(child_metric) # epoch end callbacks epoch_end_callbacks = [] if (rank is None) or (rank == 0): epoch_end_callbacks = [ Checkpoint(model_prefix, config.CHECKPOINT_FREQUENT) ] validation_monitor = ValidationMonitor( do_validation, val_loader, val_metrics, host_metric_name='Acc', label_index_in_batch=config.DATASET.LABEL_INDEX_IN_BATCH) testing_monitor = ValidationMonitor( do_validation, test_loader, val_metrics, host_metric_name='Acc', label_index_in_batch=config.DATASET.LABEL_INDEX_IN_BATCH, do_test=True) # optimizer initial lr before for group in optimizer.param_groups: group.setdefault('initial_lr', group['lr']) # resume/auto-resume if rank is None or rank == 0: smart_resume(model, optimizer, validation_monitor, config, model_prefix, logger) if args.dist: begin_epoch = torch.tensor(config.TRAIN.BEGIN_EPOCH).cuda() distributed.broadcast(begin_epoch, src=0) config.TRAIN.BEGIN_EPOCH = begin_epoch.item() # batch end callbacks batch_size = len(config.GPUS.split(',')) * config.TRAIN.BATCH_IMAGES batch_end_callbacks = [ Speedometer(batch_size, config.LOG_FREQUENT, batches_per_epoch=len(train_loader), epochs=config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) ] # setup lr step and lr scheduler if config.TRAIN.LR_SCHEDULE == 'plateau': print("Warning: not support resuming on plateau lr schedule!") lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='max', factor=config.TRAIN.LR_FACTOR, patience=1, verbose=True, threshold=1e-4, threshold_mode='rel', cooldown=2, min_lr=0, eps=1e-8) elif config.TRAIN.LR_SCHEDULE == 'triangle': lr_scheduler = WarmupLinearSchedule( optimizer, config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0, t_total=int(config.TRAIN.END_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS), last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1) elif config.TRAIN.LR_SCHEDULE == 'step': lr_iters = [ int(epoch * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) for epoch in config.TRAIN.LR_STEP ] lr_scheduler = 
WarmupMultiStepLR( optimizer, milestones=lr_iters, gamma=config.TRAIN.LR_FACTOR, warmup_factor=config.TRAIN.WARMUP_FACTOR, warmup_iters=config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0, warmup_method=config.TRAIN.WARMUP_METHOD, last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1) else: raise ValueError("Not support lr schedule: {}.".format( config.TRAIN.LR_SCHEDULE)) # broadcast parameter and optimizer state from rank 0 before training start if args.dist: for v in model.state_dict().values(): distributed.broadcast(v, src=0) # for v in optimizer.state_dict().values(): # distributed.broadcast(v, src=0) best_epoch = torch.tensor(validation_monitor.best_epoch).cuda() best_val = torch.tensor(validation_monitor.best_val).cuda() distributed.broadcast(best_epoch, src=0) distributed.broadcast(best_val, src=0) validation_monitor.best_epoch = best_epoch.item() validation_monitor.best_val = best_val.item() # apex: amp fp16 mixed-precision training if config.TRAIN.FP16: # model.apply(bn_fp16_half_eval) model, optimizer = amp.initialize( model, optimizer, opt_level='O2', keep_batchnorm_fp32=False, loss_scale=config.TRAIN.FP16_LOSS_SCALE, min_loss_scale=32.0) if args.dist: model = Apex_DDP(model, delay_allreduce=True) train(model, optimizer, lr_scheduler, train_loader, train_sampler, train_metrics, config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH, logger, rank=rank, batch_end_callbacks=batch_end_callbacks, epoch_end_callbacks=epoch_end_callbacks, writer=writer, validation_monitor=validation_monitor, fp16=config.TRAIN.FP16, clip_grad_norm=config.TRAIN.CLIP_GRAD_NORM, gradient_accumulate_steps=config.TRAIN.GRAD_ACCUMULATE_STEPS, testing_monitor=testing_monitor) return rank, model
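# Minimal sketch of the LR-scaled parameter grouping that every train_net variant in this
# file builds from config.TRAIN.LR_MULT (a list of (name substring, lr multiplier) pairs):
# parameters whose names match a substring get base_lr * multiplier, all remaining
# parameters fall into a default group. The stand-in model, the ('0.', 0.1) rule and the
# batch size of 64 below are illustrative assumptions, not project config.
import torch.nn as nn
import torch.optim as optim

_demo_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
_demo_lr_mult = [('0.', 0.1)]          # stand-in for config.TRAIN.LR_MULT
_demo_base_lr = 1e-4 * 64              # LR scaled by the effective batch size

_demo_param_groups = [
    {'params': [p for n, p in _demo_model.named_parameters() if k in n],
     'lr': _demo_base_lr * mult}
    for k, mult in _demo_lr_mult
]
_demo_param_groups.append(
    {'params': [p for n, p in _demo_model.named_parameters()
                if all(k not in n for k, _ in _demo_lr_mult)]})
_demo_optimizer = optim.SGD(_demo_param_groups, lr=_demo_base_lr,
                            momentum=0.9, weight_decay=1e-4)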
def train_net(args, config): # setup logger logger, final_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET[0].TRAIN_IMAGE_SET if isinstance(config.DATASET, list) else config.DATASET.TRAIN_IMAGE_SET, split='train') model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX) if args.log_dir is None: args.log_dir = os.path.join(final_output_path, 'tensorboard_logs') pprint.pprint(args) logger.info('training args:{}\n'.format(args)) pprint.pprint(config) logger.info('training config:{}\n'.format(pprint.pformat(config))) # manually set random seed if config.RNG_SEED > -1: random.seed(config.RNG_SEED) np.random.seed(config.RNG_SEED) torch.random.manual_seed(config.RNG_SEED) torch.cuda.manual_seed_all(config.RNG_SEED) # cudnn torch.backends.cudnn.benchmark = False if args.cudnn_off: torch.backends.cudnn.enabled = False if args.dist: model = eval(config.MODULE)(config) local_rank = int(os.environ.get('LOCAL_RANK') or 0) config.GPUS = str(local_rank) torch.cuda.set_device(local_rank) master_address = os.environ['MASTER_ADDR'] # master_port = int(os.environ['MASTER_PORT'] or 23456) # master_port = int(9997) master_port = int(9995) world_size = int(os.environ['WORLD_SIZE'] or 1) rank = int(os.environ['RANK'] or 0) if args.slurm: distributed.init_process_group(backend='nccl') else: distributed.init_process_group( backend='nccl', init_method='tcp://{}:{}'.format(master_address, master_port), world_size=world_size, rank=rank, group_name='mtorch') print(f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}') torch.cuda.set_device(local_rank) config.GPUS = str(local_rank) model = model.cuda() if not config.TRAIN.FP16: model = DDP(model, device_ids=[local_rank], output_device=local_rank) if rank == 0: summary_parameters(model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model, logger) shutil.copy(args.cfg, final_output_path) shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path) writer = None if args.log_dir is not None: tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank)) if not os.path.exists(tb_log_dir): os.makedirs(tb_log_dir) writer = SummaryWriter(log_dir=tb_log_dir) if isinstance(config.DATASET, list): train_loaders_and_samplers = make_dataloaders(config, mode='train', distributed=True, num_replicas=world_size, rank=rank, expose_sampler=True) val_loaders = make_dataloaders(config, mode='val', distributed=True, num_replicas=world_size, rank=rank) train_loader = MultiTaskDataLoader([loader for loader, _ in train_loaders_and_samplers]) val_loader = MultiTaskDataLoader(val_loaders) train_sampler = train_loaders_and_samplers[0][1] else: train_loader, train_sampler = make_dataloader(config, mode='train', distributed=True, num_replicas=world_size, rank=rank, expose_sampler=True) val_loader = make_dataloader(config, mode='val', distributed=True, num_replicas=world_size, rank=rank) batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1: batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS base_lr = config.TRAIN.LR * batch_size optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if _k in n], 'lr': base_lr * _lr_mult} for _k, _lr_mult in config.TRAIN.LR_MULT] optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters() if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])]}) if config.TRAIN.OPTIMIZER == 'SGD': 
optimizer = optim.SGD(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'Adam': optimizer = optim.Adam(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'AdamW': optimizer = AdamW(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, betas=(0.9, 0.999), eps=1e-6, weight_decay=config.TRAIN.WD, correct_bias=True) else: raise ValueError('Not support optimizer {}!'.format(config.TRAIN.OPTIMIZER)) total_gpus = world_size else: #os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS model = eval(config.MODULE)(config) summary_parameters(model, logger) shutil.copy(args.cfg, final_output_path) shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path) num_gpus = len(config.GPUS.split(',')) assert num_gpus <= 1 or (not config.TRAIN.FP16), "Not support fp16 with torch.nn.DataParallel. " \ "Please use amp.parallel.DistributedDataParallel instead." total_gpus = num_gpus rank = None writer = SummaryWriter(log_dir=args.log_dir) if args.log_dir is not None else None # model if num_gpus > 1: model = torch.nn.DataParallel(model, device_ids=[int(d) for d in config.GPUS.split(',')]).cuda() else: torch.cuda.set_device(int(config.GPUS)) model.cuda() # loader if isinstance(config.DATASET, list): train_loaders = make_dataloaders(config, mode='train', distributed=False) val_loaders = make_dataloaders(config, mode='val', distributed=False) train_loader = MultiTaskDataLoader(train_loaders) val_loader = MultiTaskDataLoader(val_loaders) else: train_loader = make_dataloader(config, mode='train', distributed=False) val_loader = make_dataloader(config, mode='val', distributed=False) train_sampler = None batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1: batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS base_lr = config.TRAIN.LR * batch_size optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if _k in n], 'lr': base_lr * _lr_mult} for _k, _lr_mult in config.TRAIN.LR_MULT] optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters() if all([_k not in n for _k, _ in config.TRAIN.LR_MULT])]}) if config.TRAIN.OPTIMIZER == 'SGD': optimizer = optim.SGD(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'Adam': optimizer = optim.Adam(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'AdamW': optimizer = AdamW(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, betas=(0.9, 0.999), eps=1e-6, weight_decay=config.TRAIN.WD, correct_bias=True) else: raise ValueError('Not support optimizer {}!'.format(config.TRAIN.OPTIMIZER)) # partial load pretrain state dict if config.NETWORK.PARTIAL_PRETRAIN != "": pretrain_state_dict = torch.load(config.NETWORK.PARTIAL_PRETRAIN, map_location=lambda storage, loc: storage)['state_dict'] prefix_change = [prefix_change.split('->') for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES] if len(prefix_change) > 0: pretrain_state_dict_parsed = {} for k, v in pretrain_state_dict.items(): no_match = True for pretrain_prefix, new_prefix in prefix_change: if k.startswith(pretrain_prefix): k = new_prefix + 
k[len(pretrain_prefix):] pretrain_state_dict_parsed[k] = v no_match = False break if no_match: pretrain_state_dict_parsed[k] = v pretrain_state_dict = pretrain_state_dict_parsed # FM edit: introduce alternative initialisations if config.NETWORK.INITIALISATION=='hybrid': smart_hybrid_partial_load_model_state_dict(model, pretrain_state_dict) elif config.NETWORK.INITIALISATION=='skip': smart_skip_partial_load_model_state_dict(model, pretrain_state_dict) else: smart_partial_load_model_state_dict(model, pretrain_state_dict) # metrics metric_kwargs = {'allreduce': args.dist, 'num_replicas': world_size if args.dist else 1} train_metrics_list = [] val_metrics_list = [] if config.NETWORK.WITH_REL_LOSS: train_metrics_list.append(retrieval_metrics.RelationshipAccuracy(**metric_kwargs)) val_metrics_list.append(retrieval_metrics.RelationshipAccuracy(**metric_kwargs)) if config.NETWORK.WITH_MLM_LOSS: if config.MODULE == 'ResNetVLBERTForPretrainingMultitask': train_metrics_list.append(retrieval_metrics.MLMAccuracyWVC(**metric_kwargs)) train_metrics_list.append(retrieval_metrics.MLMAccuracyAUX(**metric_kwargs)) val_metrics_list.append(retrieval_metrics.MLMAccuracyWVC(**metric_kwargs)) val_metrics_list.append(retrieval_metrics.MLMAccuracyAUX(**metric_kwargs)) else: train_metrics_list.append(retrieval_metrics.MLMAccuracy(**metric_kwargs)) val_metrics_list.append(retrieval_metrics.MLMAccuracy(**metric_kwargs)) if config.NETWORK.WITH_MVRC_LOSS: train_metrics_list.append(retrieval_metrics.MVRCAccuracy(**metric_kwargs)) val_metrics_list.append(retrieval_metrics.MVRCAccuracy(**metric_kwargs)) for output_name, display_name in config.TRAIN.LOSS_LOGGERS: train_metrics_list.append(retrieval_metrics.LossLogger(output_name, display_name=display_name, **metric_kwargs)) val_metrics_list.append(retrieval_metrics.LossLogger(output_name, display_name=display_name, **metric_kwargs)) train_metrics = CompositeEvalMetric() val_metrics = CompositeEvalMetric() for child_metric in train_metrics_list: train_metrics.add(child_metric) for child_metric in val_metrics_list: val_metrics.add(child_metric) # epoch end callbacks epoch_end_callbacks = [] if (rank is None) or (rank == 0): epoch_end_callbacks = [Checkpoint(model_prefix, config.CHECKPOINT_FREQUENT)] host_metric_name = 'MLMAcc' if not config.MODULE == 'ResNetVLBERTForPretrainingMultitask' else 'MLMAccWVC' validation_monitor = ValidationMonitor(do_validation, val_loader, val_metrics, host_metric_name=host_metric_name) # optimizer initial lr before for group in optimizer.param_groups: group.setdefault('initial_lr', group['lr']) # resume/auto-resume if rank is None or rank == 0: smart_resume(model, optimizer, validation_monitor, config, model_prefix, logger) if args.dist: begin_epoch = torch.tensor(config.TRAIN.BEGIN_EPOCH).cuda() distributed.broadcast(begin_epoch, src=0) config.TRAIN.BEGIN_EPOCH = begin_epoch.item() # batch end callbacks batch_size = len(config.GPUS.split(',')) * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) batch_end_callbacks = [Speedometer(batch_size, config.LOG_FREQUENT, batches_per_epoch=len(train_loader), epochs=config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH)] # setup lr step and lr scheduler if config.TRAIN.LR_SCHEDULE == 'plateau': print("Warning: not support resuming on plateau lr schedule!") lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=config.TRAIN.LR_FACTOR, patience=1, verbose=True, threshold=1e-4, threshold_mode='rel', cooldown=2, min_lr=0, 
eps=1e-8) elif config.TRAIN.LR_SCHEDULE == 'triangle': lr_scheduler = WarmupLinearSchedule(optimizer, config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0, t_total=int(config.TRAIN.END_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS), last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1) elif config.TRAIN.LR_SCHEDULE == 'step': lr_iters = [int(epoch * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) for epoch in config.TRAIN.LR_STEP] lr_scheduler = WarmupMultiStepLR(optimizer, milestones=lr_iters, gamma=config.TRAIN.LR_FACTOR, warmup_factor=config.TRAIN.WARMUP_FACTOR, warmup_iters=config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0, warmup_method=config.TRAIN.WARMUP_METHOD, last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1) else: raise ValueError("Not support lr schedule: {}.".format(config.TRAIN.LR_SCHEDULE)) # broadcast parameter and optimizer state from rank 0 before training start if args.dist: for v in model.state_dict().values(): distributed.broadcast(v, src=0) # for v in optimizer.state_dict().values(): # distributed.broadcast(v, src=0) best_epoch = torch.tensor(validation_monitor.best_epoch).cuda() best_val = torch.tensor(validation_monitor.best_val).cuda() distributed.broadcast(best_epoch, src=0) distributed.broadcast(best_val, src=0) validation_monitor.best_epoch = best_epoch.item() validation_monitor.best_val = best_val.item() # apex: amp fp16 mixed-precision training if config.TRAIN.FP16: # model.apply(bn_fp16_half_eval) model, optimizer = amp.initialize(model, optimizer, opt_level='O2', keep_batchnorm_fp32=False, loss_scale=config.TRAIN.FP16_LOSS_SCALE, max_loss_scale=128.0, min_loss_scale=128.0) if args.dist: model = Apex_DDP(model, delay_allreduce=True) train(model, optimizer, lr_scheduler, train_loader, train_sampler, train_metrics, config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH, logger, rank=rank, batch_end_callbacks=batch_end_callbacks, epoch_end_callbacks=epoch_end_callbacks, writer=writer, validation_monitor=validation_monitor, fp16=config.TRAIN.FP16, clip_grad_norm=config.TRAIN.CLIP_GRAD_NORM, gradient_accumulate_steps=config.TRAIN.GRAD_ACCUMULATE_STEPS) return rank, model
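# Minimal sketch of the PARTIAL_PRETRAIN_PREFIX_CHANGES remapping that the partial-pretrain
# loading blocks above implement inline. Each rule is a string 'old_prefix->new_prefix';
# state-dict keys starting with old_prefix are renamed, everything else is kept unchanged.
# The 'module.->' example rule below is an illustrative assumption.
def _demo_remap_prefixes(state_dict, prefix_changes):
    rules = [rule.split('->') for rule in prefix_changes]
    remapped = {}
    for k, v in state_dict.items():
        for old_prefix, new_prefix in rules:
            if k.startswith(old_prefix):
                k = new_prefix + k[len(old_prefix):]
                break
        remapped[k] = v
    return remapped

# e.g. strip the 'module.' prefix left by DataParallel checkpoints:
# pretrain_state_dict = _demo_remap_prefixes(pretrain_state_dict, ['module.->'])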
def train_net(args, config):
    np.random.seed(config.RNG_SEED)
    logger, final_output_path = create_logger(config.OUTPUT_PATH, args.cfg,
                                              config.DATASET.TRAIN_IMAGE_SET)
    prefix = os.path.join(final_output_path, config.MODEL_PREFIX)

    # load symbol
    current_path = os.path.abspath(os.path.dirname(__file__))
    shutil.copy2(os.path.join(current_path, '../modules', config.MODULE + '.py'),
                 os.path.join(final_output_path, config.MODULE + '.py'))
    net = eval(config.MODULE + '.' + config.MODULE)(config)

    # setup multi-gpu
    gpu_num = len(config.GPUS)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # prepare dataset
    train_set = eval(config.DATASET.DATASET)(config.DATASET.TRAIN_IMAGE_SET,
                                             config.DATASET.ROOT_PATH,
                                             config.DATASET.DATASET_PATH,
                                             config.TRAIN.SCALES)
    test_set = eval(config.DATASET.DATASET)(config.DATASET.TEST_IMAGE_SET,
                                            config.DATASET.ROOT_PATH,
                                            config.DATASET.DATASET_PATH,
                                            config.TEST.SCALES)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=config.TRAIN.BATCH_IMAGES_PER_GPU * gpu_num,
                                               shuffle=config.TRAIN.SHUFFLE,
                                               num_workers=config.NUM_WORKER_PER_GPU * gpu_num)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=config.TEST.BATCH_IMAGES_PER_GPU * gpu_num,
                                              shuffle=False,
                                              num_workers=config.NUM_WORKER_PER_GPU * gpu_num)

    # setup metrics
    train_pred_names = net.get_pred_names(is_train=True)
    train_label_names = net.get_label_names(is_train=True)
    train_metrics = CompositeEvalMetric()
    train_metrics.add(cls_metrics.AccMetric(train_pred_names, train_label_names))
    val_pred_names = net.get_pred_names(is_train=False)
    val_label_names = net.get_label_names(is_train=False)
    val_metrics = CompositeEvalMetric()
    val_metrics.add(cls_metrics.AccMetric(val_pred_names, val_label_names))

    # setup callbacks
    batch_end_callback = [Speedometer(config.TRAIN.BATCH_IMAGES_PER_GPU * gpu_num,
                                      frequent=config.LOG_FREQUENT)]
    epoch_end_callback = [Checkpoint(os.path.join(final_output_path, config.MODEL_PREFIX)),
                          ValidationMonitor(do_validation, test_loader, val_metrics)]

    # set up optimizer
    optimizer = optim.SGD(net.parameters(),
                          lr=config.TRAIN.LR,
                          momentum=config.TRAIN.MOMENTUM,
                          weight_decay=config.TRAIN.WD,
                          nesterov=True)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               config.TRAIN.LR_STEP,
                                               config.TRAIN.LR_FACTOR)

    # init parameters: resume from a saved checkpoint once the optimizer exists,
    # so both the model weights and the optimizer state can be restored
    # (checkpoint keys 'state_dict' / 'optimizer' are assumed here)
    if config.TRAIN.RESUME:
        print('continue training from', config.TRAIN.BEGIN_EPOCH)
        model_filename = '{}-{:04d}.model'.format(prefix, config.TRAIN.BEGIN_EPOCH - 1)
        check_point = torch.load(model_filename)
        net.load_state_dict(check_point['state_dict'])
        optimizer.load_state_dict(check_point['optimizer'])

    # set up running devices
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=config.GPUS)

    # train
    train(net,
          optimizer=optimizer,
          lr_scheduler=scheduler,
          train_loader=train_loader,
          metrics=train_metrics,
          config=config,
          logger=logger,
          batch_end_callbacks=batch_end_callback,
          epoch_end_callbacks=epoch_end_callback)
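# Minimal sketch of a checkpoint layout compatible with the resume code above. The key
# names ('state_dict', 'optimizer', 'epoch') are assumptions about what the Checkpoint
# callback writes; adjust them to match the actual *.model files on disk.
import torch

def _demo_save_checkpoint(path, net, optimizer, epoch):
    torch.save({'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()}, path)

def _demo_load_checkpoint(path, net, optimizer):
    check_point = torch.load(path, map_location='cpu')
    net.load_state_dict(check_point['state_dict'])
    optimizer.load_state_dict(check_point['optimizer'])
    return check_point.get('epoch', 0)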
def train_net(args, config): # manually set random seed if config.RNG_SEED > -1: np.random.seed(config.RNG_SEED) torch.manual_seed(config.RNG_SEED) torch.random.manual_seed(config.RNG_SEED) torch.cuda.manual_seed_all(config.RNG_SEED) random.seed(config.RNG_SEED) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # cudnn torch.backends.cudnn.benchmark = False if args.cudnn_off: torch.backends.cudnn.enabled = False # parallel: distributed training for utilising multiple GPUs if args.dist: # set up the environment local_rank = int(os.environ.get('LOCAL_RANK') or 0) config.GPUS = str(local_rank) torch.cuda.set_device(local_rank) master_address = os.environ['MASTER_ADDR'] master_port = int(os.environ['MASTER_PORT'] or 23456) world_size = int(os.environ['WORLD_SIZE'] or 1) rank = int(os.environ['RANK'] or 0) # initialize process group distributed.init_process_group(backend='nccl', init_method='tcp://{}:{}'.format( master_address, master_port), world_size=world_size, rank=rank, group_name='mtorch') print( f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}' ) # set cuda devices torch.cuda.set_device(local_rank) config.GPUS = str(local_rank) # initialize the model and put it to GPU model = eval(config.MODULE)(config=config.NETWORK) model = model.cuda() # wrap the model using torch distributed data parallel model = DDP(model, device_ids=[local_rank], output_device=local_rank) # Check if the model requires policy network if config.NETWORK.TRAINING_STRATEGY in PolicyVec: policy_model = eval( config.POLICY_MODULE)(config=config.POLICY.NETWORK) policy_model = policy_model.cuda() # wrap in DDP policy_model = DDP(policy_model, device_ids=[local_rank], output_device=local_rank) # Check if the strategy is to train a knowledge distillation model if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': # initialize the teacher model teacher_model = eval(config.TEACHER.MODULE)(config=config.TEACHER) teacher_model = teacher_model.cuda() # wrap in DDP teacher_model = DDP(policy_model, device_ids=[local_rank], output_device=local_rank) # summarize the model if rank == 0: print("summarizing the main network") summary_parameters(model) if config.NETWORK.TRAINING_STRATEGY in PolicyVec: print("summarizing the policy network") summary_parameters(policy_model) if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': print("Summarizing the teacher model") summary_parameters(policy_model) # dataloaders for training, val and test set train_loader = make_dataloader(config, mode='train', distributed=True, num_replicas=world_size, rank=rank) val_loader = make_dataloader(config, mode='val', distributed=True, num_replicas=world_size, rank=rank) else: # set CUDA device in env variables config.GPUS = [*range(len( (config.GPUS).split(',')))] if args.data_parallel else str(0) print(f"config.GPUS = {config.GPUS}") # initialize the model and put is to GPU model = eval(config.MODULE)(config=config.NETWORK) # check for policy model if config.NETWORK.TRAINING_STRATEGY in PolicyVec: policy_model = eval( config.POLICY_MODULE)(config=config.POLICY.NETWORK) policy_model = policy_model.cuda() # Check if the strategy is to train a knowledge distillation model if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': # initialize the teacher model teacher_model = eval(config.TEACHER.MODULE)(config=config.TEACHER) teacher_model = teacher_model.cuda() if args.data_parallel: model = model.cuda() model = nn.DataParallel(model, device_ids=config.GPUS) if 
config.NETWORK.TRAINING_STRATEGY in PolicyVec: policy_model = nn.DataParallel(policy_model, device_ids=config.GPUS) if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': teacher_model = nn.DataParallel(teacher_model, device_ids=config.GPUS) else: torch.cuda.set_device(0) model = model.cuda() if config.NETWORK.TRAINING_STRATEGY in PolicyVec: policy_model = policy_model.cuda() if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': teacher_model = teacher_model.cuda() # summarize the model print("summarizing the model") summary_parameters(model) if config.NETWORK.TRAINING_STRATEGY in PolicyVec: print("Summarizing the policy model") summary_parameters(policy_model) if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': print("Summarizing the teacher model") summary_parameters(teacher_model) # dataloaders for training and test set train_loader = make_dataloader(config, mode='train', distributed=False) val_loader = make_dataloader(config, mode='val', distributed=False) # wandb logging wandb.watch(model, log='all') if config.NETWORK.TRAINING_STRATEGY in PolicyVec: wandb.watch(policy_model, log='all') if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': wandb.watch(teacher_model, log='all') # set up the initial learning rate initial_lr = config.TRAIN.LR # configure the optimizer try: optimizer = eval(f'optim_{config.TRAIN.OPTIMIZER}')( model=model, initial_lr=initial_lr, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WEIGHT_DECAY) except: raise ValueError(f'{config.TRAIN.OPTIMIZER}, not supported!!') if config.NETWORK.TRAINING_STRATEGY in PolicyVec: initial_lr_policy = config.POLICY.LR try: policy_optimizer = eval(f'optim_{config.POLICY.OPTIMIZER}')( model=model, initial_lr=initial_lr_policy, momentum=config.POLICY.MOMENTUM, weight_decay=config.POLICY.WEIGHT_DECAY) except: raise ValueError(f'{config.POLICY.OPTIMIZER}, not supported!!') # Load pre-trained model if config.NETWORK.PRETRAINED_MODEL != '': print( f"Loading the pretrained model from {config.NETWORK.PRETRAINED_MODEL} ..." ) pretrain_state_dict = torch.load( config.NETWORK.PRETRAINED_MODEL, map_location=lambda storage, loc: storage)['net_state_dict'] smart_model_load( model, pretrain_state_dict, loading_method=config.NETWORK.PRETRAINED_LOADING_METHOD) # Load the pre-trained teacher model if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation': # There must be a pretrained model to load from (but not in the case of an apprentice network) # assert config.TEACHER.PRETRAINED_MODEL != '', "No pre-trained model specified for the teacher" if config.TEACHER.PRETRAINED_MODEL != '': print( f"Loading the teacher network from {config.TEACHER.PRETRAINED_MODEL} ..." 
) pretrain_state_dict = torch.load( config.TEACHER.PRETRAINED_MODEL, map_location=lambda storage, loc: storage)['net_state_dict'] smart_model_load( teacher_model, pretrain_state_dict, loading_method=config.TEACHER.PRETRAINED_LOADING_METHOD) # Set up the metrics train_metrics = TrainMetrics(config, allreduce=False) val_metrics = ValMetrics(config, allreduce=args.dist) # Set up the callbacks # batch end callbacks batch_end_callbacks = None # epoch end callbacks epoch_end_callbacks = [ Checkpoint(config, val_metrics), LRScheduler(config) ] if config.NETWORK.TRAINING_STRATEGY in PolicyVec: epoch_end_callbacks.append(LRSchedulerPolicy(config)) epoch_end_callbacks.append(VisualizationPlotter()) # At last call the training function from trainer train(config=config, net=model, optimizer=optimizer, train_loader=train_loader, train_metrics=train_metrics, val_loader=val_loader, val_metrics=val_metrics, policy_net=policy_model if config.NETWORK.TRAINING_STRATEGY in PolicyVec else None, policy_optimizer=policy_optimizer if config.NETWORK.TRAINING_STRATEGY in PolicyVec else None, teacher_net=teacher_model if config.NETWORK.TRAINING_STRATEGY == 'knowledge_distillation' else None, rank=rank if args.dist else None, batch_end_callbacks=batch_end_callbacks, epoch_end_callbacks=epoch_end_callbacks)
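# Minimal sketch of the env-driven distributed setup shared by the args.dist branches
# above (LOCAL_RANK / RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT as set by
# torch.distributed.launch or torchrun). The fallback address and port are illustrative
# defaults, not project settings.
import os
import torch
import torch.distributed as distributed

def _demo_init_distributed():
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    master_address = os.environ.get('MASTER_ADDR', '127.0.0.1')
    master_port = int(os.environ.get('MASTER_PORT', 23456))

    torch.cuda.set_device(local_rank)
    distributed.init_process_group(
        backend='nccl',
        init_method='tcp://{}:{}'.format(master_address, master_port),
        world_size=world_size,
        rank=rank)
    return rank, local_rank, world_size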
def train_net(args, config): # setup logger logger, final_output_path = create_logger(config.OUTPUT_PATH, args.cfg, config.DATASET.TRAIN_IMAGE_SET, split='train') model_prefix = os.path.join(final_output_path, config.MODEL_PREFIX) if args.log_dir is None: args.log_dir = os.path.join(final_output_path, 'tensorboard_logs') # pprint.pprint(args) # logger.info('training args:{}\n'.format(args)) # pprint.pprint(config) # logger.info('training config:{}\n'.format(pprint.pformat(config))) # manually set random seed if config.RNG_SEED > -1: random.seed(a=config.RNG_SEED) np.random.seed(config.RNG_SEED) torch.random.manual_seed(config.RNG_SEED) torch.cuda.manual_seed_all(config.RNG_SEED) torch.backends.cudnn.deterministic = True imgaug.random.seed(config.RNG_SEED) # cudnn torch.backends.cudnn.benchmark = False if args.cudnn_off: torch.backends.cudnn.enabled = False if args.dist: model = eval(config.MODULE)(config) local_rank = int(os.environ.get('LOCAL_RANK') or 0) config.GPUS = str(local_rank) torch.cuda.set_device(local_rank) master_address = os.environ['MASTER_ADDR'] master_port = int(os.environ['MASTER_PORT'] or 23456) world_size = int(os.environ['WORLD_SIZE'] or 1) rank = int(os.environ['RANK'] or 0) if rank == 0: pprint.pprint(args) logger.info('training args:{}\n'.format(args)) pprint.pprint(config) logger.info('training config:{}\n'.format(pprint.pformat(config))) if args.slurm: distributed.init_process_group(backend='nccl') else: try: distributed.init_process_group( backend='nccl', init_method='tcp://{}:{}'.format(master_address, master_port), world_size=world_size, rank=rank, group_name='mtorch') except RuntimeError: pass print( f'native distributed, size: {world_size}, rank: {rank}, local rank: {local_rank}' ) torch.cuda.set_device(local_rank) config.GPUS = str(local_rank) model = model.cuda() if not config.TRAIN.FP16: model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) if rank == 0: summary_parameters( model.module if isinstance( model, torch.nn.parallel.DistributedDataParallel) else model, logger) shutil.copy(args.cfg, final_output_path) shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path) writer = None if args.log_dir is not None: tb_log_dir = os.path.join(args.log_dir, 'rank{}'.format(rank)) if not os.path.exists(tb_log_dir): os.makedirs(tb_log_dir) writer = SummaryWriter(log_dir=tb_log_dir) batch_size = world_size * (sum(config.TRAIN.BATCH_IMAGES) if isinstance(config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1: batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS base_lr = config.TRAIN.LR * batch_size optimizer_grouped_parameters = [{ 'params': [p for n, p in model.named_parameters() if _k in n], 'lr': base_lr * _lr_mult } for _k, _lr_mult in config.TRAIN.LR_MULT] optimizer_grouped_parameters.append({ 'params': [ p for n, p in model.named_parameters() if all([_k not in n for _k, _ in config.TRAIN.LR_MULT]) ] }) if config.TRAIN.OPTIMIZER == 'SGD': optimizer = optim.SGD(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'Adam': optimizer = optim.Adam(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'AdamW': optimizer = AdamW(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, betas=(0.9, 0.999), eps=1e-6, weight_decay=config.TRAIN.WD, correct_bias=True) else: raise 
ValueError('Not support optimizer {}!'.format( config.TRAIN.OPTIMIZER)) total_gpus = world_size train_loader, train_sampler = make_dataloader(config, mode='train', distributed=True, num_replicas=world_size, rank=rank, expose_sampler=True) val_loader = make_dataloader(config, mode='val', distributed=True, num_replicas=world_size, rank=rank) else: pprint.pprint(args) logger.info('training args:{}\n'.format(args)) pprint.pprint(config) logger.info('training config:{}\n'.format(pprint.pformat(config))) #os.environ['CUDA_VISIBLE_DEVICES'] = config.GPUS model = eval(config.MODULE)(config) summary_parameters(model, logger) shutil.copy(args.cfg, final_output_path) shutil.copy(inspect.getfile(eval(config.MODULE)), final_output_path) num_gpus = len(config.GPUS.split(',')) # assert num_gpus <= 1 or (not config.TRAIN.FP16), "Not support fp16 with torch.nn.DataParallel. " \ # "Please use amp.parallel.DistributedDataParallel instead." if num_gpus > 1 and config.TRAIN.FP16: logger.warning("Not support fp16 with torch.nn.DataParallel.") config.TRAIN.FP16 = False total_gpus = num_gpus rank = None writer = SummaryWriter( log_dir=args.log_dir) if args.log_dir is not None else None if hasattr(model, 'setup_adapter'): logger.info('Setting up adapter modules!') model.setup_adapter() # model if num_gpus > 1: model = torch.nn.DataParallel( model, device_ids=[int(d) for d in config.GPUS.split(',')]).cuda() else: torch.cuda.set_device(int(config.GPUS)) model.cuda() # loader # train_set = 'train+val' if config.DATASET.TRAIN_WITH_VAL else 'train' train_loader = make_dataloader(config, mode='train', distributed=False) val_loader = make_dataloader(config, mode='val', distributed=False) train_sampler = None batch_size = num_gpus * (sum(config.TRAIN.BATCH_IMAGES) if isinstance( config.TRAIN.BATCH_IMAGES, list) else config.TRAIN.BATCH_IMAGES) if config.TRAIN.GRAD_ACCUMULATE_STEPS > 1: batch_size = batch_size * config.TRAIN.GRAD_ACCUMULATE_STEPS base_lr = config.TRAIN.LR * batch_size optimizer_grouped_parameters = [{ 'params': [p for n, p in model.named_parameters() if _k in n], 'lr': base_lr * _lr_mult } for _k, _lr_mult in config.TRAIN.LR_MULT] optimizer_grouped_parameters.append({ 'params': [ p for n, p in model.named_parameters() if all([_k not in n for _k, _ in config.TRAIN.LR_MULT]) ] }) if config.TRAIN.OPTIMIZER == 'SGD': optimizer = optim.SGD(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, momentum=config.TRAIN.MOMENTUM, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'Adam': optimizer = optim.Adam(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, weight_decay=config.TRAIN.WD) elif config.TRAIN.OPTIMIZER == 'AdamW': optimizer = AdamW(optimizer_grouped_parameters, lr=config.TRAIN.LR * batch_size, betas=(0.9, 0.999), eps=1e-6, weight_decay=config.TRAIN.WD, correct_bias=True) else: raise ValueError('Not support optimizer {}!'.format( config.TRAIN.OPTIMIZER)) # partial load pretrain state dict if config.NETWORK.PARTIAL_PRETRAIN != "": pretrain_state_dict = torch.load( config.NETWORK.PARTIAL_PRETRAIN, map_location=lambda storage, loc: storage)['state_dict'] prefix_change = [ prefix_change.split('->') for prefix_change in config.NETWORK.PARTIAL_PRETRAIN_PREFIX_CHANGES ] if len(prefix_change) > 0: pretrain_state_dict_parsed = {} for k, v in pretrain_state_dict.items(): no_match = True for pretrain_prefix, new_prefix in prefix_change: if k.startswith(pretrain_prefix): k = new_prefix + k[len(pretrain_prefix):] pretrain_state_dict_parsed[k] = v no_match = False break if no_match: 
pretrain_state_dict_parsed[k] = v pretrain_state_dict = pretrain_state_dict_parsed smart_partial_load_model_state_dict(model, pretrain_state_dict) # pretrained classifier # if config.NETWORK.CLASSIFIER_PRETRAINED: # print('Initializing classifier weight from pretrained word embeddings...') # answers_word_embed = [] # for k, v in model.state_dict().items(): # if 'word_embeddings.weight' in k: # word_embeddings = v.detach().clone() # break # for answer in train_loader.dataset.answer_vocab: # a_tokens = train_loader.dataset.tokenizer.tokenize(answer) # a_ids = train_loader.dataset.tokenizer.convert_tokens_to_ids(a_tokens) # a_word_embed = (torch.stack([word_embeddings[a_id] for a_id in a_ids], dim=0)).mean(dim=0) # answers_word_embed.append(a_word_embed) # answers_word_embed_tensor = torch.stack(answers_word_embed, dim=0) # for name, module in model.named_modules(): # if name.endswith('final_mlp'): # module[-1].weight.data = answers_word_embed_tensor.to(device=module[-1].weight.data.device) # metrics train_metrics_list = [ cls_metrics.Accuracy(allreduce=args.dist, num_replicas=world_size if args.dist else 1) ] val_metrics_list = [ cls_metrics.Accuracy(allreduce=args.dist, num_replicas=world_size if args.dist else 1), cls_metrics.RocAUC(allreduce=args.dist, num_replicas=world_size if args.dist else 1) ] for output_name, display_name in config.TRAIN.LOSS_LOGGERS: train_metrics_list.append( cls_metrics.LossLogger( output_name, display_name=display_name, allreduce=args.dist, num_replicas=world_size if args.dist else 1)) train_metrics = CompositeEvalMetric() val_metrics = CompositeEvalMetric() for child_metric in train_metrics_list: train_metrics.add(child_metric) for child_metric in val_metrics_list: val_metrics.add(child_metric) # epoch end callbacks epoch_end_callbacks = [] if (rank is None) or (rank == 0): epoch_end_callbacks = [ Checkpoint(model_prefix, config.CHECKPOINT_FREQUENT) ] validation_monitor = ValidationMonitor( do_validation, val_loader, val_metrics, host_metric_name='RocAUC', label_index_in_batch=config.DATASET.LABEL_INDEX_IN_BATCH, model_dir=os.path.dirname(model_prefix)) # optimizer initial lr before for group in optimizer.param_groups: group.setdefault('initial_lr', group['lr']) # resume/auto-resume if rank is None or rank == 0: smart_resume(model, optimizer, validation_monitor, config, model_prefix, logger) if args.dist: begin_epoch = torch.tensor(config.TRAIN.BEGIN_EPOCH).cuda() distributed.broadcast(begin_epoch, src=0) config.TRAIN.BEGIN_EPOCH = begin_epoch.item() # batch end callbacks batch_size = len(config.GPUS.split(',')) * config.TRAIN.BATCH_IMAGES batch_end_callbacks = [ Speedometer(batch_size, config.LOG_FREQUENT, batches_per_epoch=len(train_loader), epochs=config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) ] # setup lr step and lr scheduler if config.TRAIN.LR_SCHEDULE == 'plateau': print("Warning: not support resuming on plateau lr schedule!") lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='max', factor=config.TRAIN.LR_FACTOR, patience=1, verbose=True, threshold=1e-4, threshold_mode='rel', cooldown=2, min_lr=0, eps=1e-8) elif config.TRAIN.LR_SCHEDULE == 'triangle': lr_scheduler = WarmupLinearSchedule( optimizer, config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0, t_total=int(config.TRAIN.END_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS), last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1) elif config.TRAIN.LR_SCHEDULE == 'step': lr_iters = [ int(epoch * 
len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) for epoch in config.TRAIN.LR_STEP ] lr_scheduler = WarmupMultiStepLR( optimizer, milestones=lr_iters, gamma=config.TRAIN.LR_FACTOR, warmup_factor=config.TRAIN.WARMUP_FACTOR, warmup_iters=config.TRAIN.WARMUP_STEPS if config.TRAIN.WARMUP else 0, warmup_method=config.TRAIN.WARMUP_METHOD, last_epoch=int(config.TRAIN.BEGIN_EPOCH * len(train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS) - 1) else: raise ValueError("Not support lr schedule: {}.".format( config.TRAIN.LR_SCHEDULE)) if config.TRAIN.SWA: assert config.TRAIN.SWA_START_EPOCH < config.TRAIN.END_EPOCH if not config.TRAIN.DEBUG: true_epoch_step = len( train_loader) / config.TRAIN.GRAD_ACCUMULATE_STEPS else: true_epoch_step = 50 step_per_cycle = config.TRAIN.SWA_EPOCH_PER_CYCLE * true_epoch_step # swa_scheduler = torch.optim.lr_scheduler.CyclicLR( # optimizer, # base_lr=config.TRAIN.SWA_MIN_LR * batch_size, # max_lr=config.TRAIN.SWA_MAX_LR * batch_size, # cycle_momentum=False, # step_size_up=10, # step_size_down=step_per_cycle - 10) anneal_steps = max( 1, (config.TRAIN.END_EPOCH - config.TRAIN.SWA_START_EPOCH) // 4) * step_per_cycle anneal_steps = int(anneal_steps) swa_scheduler = SWALR(optimizer, anneal_epochs=anneal_steps, anneal_strategy='linear', swa_lr=config.TRAIN.SWA_MAX_LR * batch_size) else: swa_scheduler = None if config.TRAIN.ROC_STAR: assert config.TRAIN.ROC_START_EPOCH < config.TRAIN.END_EPOCH roc_star = RocStarLoss( delta=2.0, sample_size=config.TRAIN.ROC_SAMPLE_SIZE, sample_size_gamma=config.TRAIN.ROC_SAMPLE_SIZE * 2, update_gamma_each=config.TRAIN.ROC_SAMPLE_SIZE, ) else: roc_star = None # broadcast parameter and optimizer state from rank 0 before training start if args.dist: for v in model.state_dict().values(): distributed.broadcast(v, src=0) # for v in optimizer.state_dict().values(): # distributed.broadcast(v, src=0) best_epoch = torch.tensor(validation_monitor.best_epoch).cuda() best_val = torch.tensor(validation_monitor.best_val).cuda() distributed.broadcast(best_epoch, src=0) distributed.broadcast(best_val, src=0) validation_monitor.best_epoch = best_epoch.item() validation_monitor.best_val = best_val.item() # apex: amp fp16 mixed-precision training if config.TRAIN.FP16: # model.apply(bn_fp16_half_eval) model, optimizer = amp.initialize( model, optimizer, opt_level='O2', keep_batchnorm_fp32=False, loss_scale=config.TRAIN.FP16_LOSS_SCALE, min_loss_scale=32.0) if args.dist: model = Apex_DDP(model, delay_allreduce=True) # NOTE: final_model == model if not using SWA, else final_model == AveragedModel(model) final_model = train( model, optimizer, lr_scheduler, train_loader, train_sampler, train_metrics, config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH, logger, fp16=config.TRAIN.FP16, rank=rank, writer=writer, batch_end_callbacks=batch_end_callbacks, epoch_end_callbacks=epoch_end_callbacks, validation_monitor=validation_monitor, clip_grad_norm=config.TRAIN.CLIP_GRAD_NORM, gradient_accumulate_steps=config.TRAIN.GRAD_ACCUMULATE_STEPS, ckpt_path=config.TRAIN.CKPT_PATH, swa_scheduler=swa_scheduler, swa_start_epoch=config.TRAIN.SWA_START_EPOCH, swa_cycle_epoch=config.TRAIN.SWA_EPOCH_PER_CYCLE, swa_use_scheduler=config.TRAIN.SWA_SCHEDULE, roc_star=roc_star, roc_star_start_epoch=config.TRAIN.ROC_START_EPOCH, roc_interleave=config.TRAIN.ROC_INTERLEAVE, debug=config.TRAIN.DEBUG, ) return rank, final_model
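# Minimal sketch of the SWA machinery referenced in the last variant, using
# torch.optim.swa_utils (AveragedModel + SWALR). The start epoch, epoch count and
# learning rates below are placeholders, not values from config.TRAIN.
import torch
from torch.optim.swa_utils import AveragedModel, SWALR

_demo_model = torch.nn.Linear(16, 2)
_demo_optimizer = torch.optim.SGD(_demo_model.parameters(), lr=1e-3)
_demo_swa_model = AveragedModel(_demo_model)
_demo_swa_scheduler = SWALR(_demo_optimizer, swa_lr=5e-4,
                            anneal_epochs=10, anneal_strategy='linear')
_demo_swa_start_epoch = 5

for _epoch in range(10):
    # ... run one normal training epoch here ...
    if _epoch >= _demo_swa_start_epoch:
        _demo_swa_model.update_parameters(_demo_model)  # accumulate the running weight average
        _demo_swa_scheduler.step()                      # anneal towards swa_lr

# before evaluating, recompute BatchNorm statistics for the averaged weights:
# torch.optim.swa_utils.update_bn(train_loader, _demo_swa_model)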