import argparse
from datetime import datetime

import torch
from torch.utils.data import DataLoader

# Repo-local helpers assumed importable from the surrounding project:
# Read_Data (the dataset), func (the collate function), and cuda (moves a
# model to the GPU). A sketch of WarmupMultiStepLR is given after train() below.


def train(model, config, step, x, pre_model_file, model_file=None):
    model = model(config)
    print(model)
    # Keep the network in eval mode so BatchNorm running statistics stay frozen.
    model.eval()

    # Load the pretrained backbone weights into the 'features.' submodule.
    model_dic = model.state_dict()
    pretrained_dict = torch.load(pre_model_file, map_location='cpu')
    pretrained_dict = {'features.' + k: v for k, v in pretrained_dict.items()
                       if 'features.' + k in model_dic}
    print('*******', len(pretrained_dict))
    model_dic.update(pretrained_dict)
    model.load_state_dict(model_dic)

    # When resuming (step > 0), overwrite everything with the saved checkpoint.
    if step > 0:
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        print(model_file)
    else:
        print(pre_model_file)
    cuda(model)

    train_params = list(model.parameters())

    # Linear-scaling rule for the base LR; apply the schedule's decays up front
    # when resuming past a milestone.
    lr = config.lr * config.batch_size_per_GPU
    if step >= 60000 * x:
        lr = lr / 10
    if step >= 80000 * x:
        lr = lr / 10
    print('lr ******************', lr)
    print('weight_decay ******************', config.weight_decay)

    # Split parameters into three groups: conv/FC weights (with weight decay),
    # BN/GN weights (no decay), and biases (no decay, scaled LR).
    bias_p = []
    weight_p = []
    bn_weight_p = []
    print(len(train_params))
    for name, p in model.named_parameters():
        print(name, p.shape)
        if len(p.shape) == 1:
            if 'bias' in name:
                bias_p.append(p)
            else:
                bn_weight_p.append(p)
        else:
            weight_p.append(p)
    print(len(weight_p), len(bias_p), len(bn_weight_p))
    opt = torch.optim.SGD(
        [{'params': weight_p, 'weight_decay': config.weight_decay, 'lr': lr},
         {'params': bn_weight_p, 'lr': lr},
         {'params': bias_p, 'lr': lr * config.bias_lr_factor}],
        momentum=0.9,
    )

    scheduler = WarmupMultiStepLR(opt, [60000 * x, 80000 * x],
                                  warmup_factor=1 / 3, warmup_iters=500)

    dataset = Read_Data(config)
    dataloader = DataLoader(dataset, batch_size=config.batch_size_per_GPU,
                            collate_fn=func, shuffle=True, drop_last=True,
                            pin_memory=True, num_workers=2)

    epochs = 10000
    flag = False
    print('start: step=', step)
    for epoch in range(epochs):
        for imgs, bboxes, num_b, num_H, num_W, masks in dataloader:
            loss = model(imgs, bboxes, num_b, num_H, num_W, masks)
            loss = loss / imgs.shape[0]
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 5, norm_type=2)
            opt.step()
            scheduler.step()

            if step % 20 == 0:
                # model.a .. model.f are the per-head loss components cached by the model.
                print(datetime.now(), 'loss:%.4f' % loss,
                      'rpn_cls_loss:%.4f' % model.a, 'rpn_box_loss:%.4f' % model.b,
                      'fast_cls_loss:%.4f' % model.c, 'fast_box_loss:%.4f' % model.d,
                      'mask_loss:%.4f' % model.f,
                      model.fast_num, model.fast_num_P,
                      opt.param_groups[0]['lr'], step)
            step += 1

            # Manual LR decay, superseded by WarmupMultiStepLR:
            # if step == int(60000 * x) or step == int(80000 * x):
            #     for param_group in opt.param_groups:
            #         param_group['lr'] = param_group['lr'] / 10

            if (step <= 10000 and step % 1000 == 0) or step % 5000 == 0 or step == 1:
                torch.save(model.state_dict(), './models/Mask_Rcnn_50_%d_1.pth' % step)
            if step >= 90010:
                flag = True
                break
        if flag:
            break
    torch.save(model.state_dict(), './models/Mask_Rcnn_50_final_1.pth')
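
# WarmupMultiStepLR is not defined in this file. The class below is a minimal
# sketch matching only the constructor call used above (milestones, warmup_factor,
# warmup_iters); the decay factor gamma=0.1 and the linear warmup shape follow the
# common maskrcnn-benchmark recipe and are assumptions, so the repo's own
# implementation may differ.
from bisect import bisect_right


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1,
                 warmup_factor=1.0 / 3, warmup_iters=500, last_epoch=-1):
        self.milestones = list(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Ramp linearly from warmup_factor up to 1 over the first warmup_iters
        # steps, then decay by gamma at each milestone.
        warmup = 1.0
        if self.last_epoch < self.warmup_iters:
            alpha = self.last_epoch / self.warmup_iters
            warmup = self.warmup_factor * (1 - alpha) + alpha
        return [base_lr * warmup *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
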
def train_dist(model, config, step, x, pre_model_file, model_file=None):
    # torch.distributed.launch passes --local_rank to every spawned process.
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    local_rank = args.local_rank
    print('******************* local_rank', local_rank)
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    assert torch.distributed.is_initialized()

    batch_size = config.gpus * config.batch_size_per_GPU
    print('--------batch_size--------', batch_size)

    model = model(config)
    # Keep BatchNorm running statistics frozen during training.
    model.eval()

    # Load the pretrained backbone weights into the 'features.' submodule.
    model_dic = model.state_dict()
    pretrained_dict = torch.load(pre_model_file, map_location='cpu')
    pretrained_dict = {'features.' + k: v for k, v in pretrained_dict.items()
                       if 'features.' + k in model_dic}
    print('pretrained_dict *******', len(pretrained_dict))
    model_dic.update(pretrained_dict)
    model.load_state_dict(model_dic)

    # When resuming (step > 0), overwrite everything with the saved checkpoint.
    if step > 0:
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        print(model_file)
    else:
        print(pre_model_file)

    model = torch.nn.parallel.DistributedDataParallel(
        model.cuda(), device_ids=[local_rank], output_device=local_rank,
        # broadcast_buffers=False should be removed if we update BatchNorm stats.
        broadcast_buffers=False,
    )

    lr = config.lr * config.batch_size_per_GPU
    if step >= 60000 * x:
        lr = lr / 10
    if step >= 80000 * x:
        lr = lr / 10
    print('lr ******************', lr)
    print('weight_decay ******************', config.weight_decay)

    train_params = list(model.parameters())

    # Same three-way parameter grouping as in train().
    bias_p = []
    weight_p = []
    bn_weight_p = []
    print(len(train_params))
    for name, p in model.named_parameters():
        print(name, p.shape)
        if len(p.shape) == 1:
            if 'bias' in name:
                bias_p.append(p)
            else:
                bn_weight_p.append(p)
        else:
            weight_p.append(p)
    print(len(weight_p), len(bias_p), len(bn_weight_p))
    opt = torch.optim.SGD(
        [{'params': weight_p, 'weight_decay': config.weight_decay, 'lr': lr},
         {'params': bn_weight_p, 'lr': lr},
         {'params': bias_p, 'lr': lr * config.bias_lr_factor}],
        momentum=0.9,
    )

    dataset = Read_Data(config)
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=config.batch_size_per_GPU,
                            sampler=train_sampler, collate_fn=func, drop_last=True,
                            pin_memory=True, num_workers=16)
    scheduler = WarmupMultiStepLR(opt, [60000 * x, 80000 * x],
                                  warmup_factor=1 / 3, warmup_iters=500)

    epochs = 10000
    flag = False
    print('start: step=', step)
    # When resuming, fast-forward the scheduler to the saved step.
    if step > 100:
        for i in range(step):
            scheduler.step()
    for epoch in range(epochs):
        # Reshuffle the sampler each epoch so every process sees a new permutation.
        train_sampler.set_epoch(epoch)
        for imgs, bboxes, num_b, num_H, num_W, masks in dataloader:
            loss = model(imgs, bboxes, num_b, num_H, num_W, masks)
            loss = loss / imgs.shape[0]
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 5, norm_type=2)
            opt.step()
            scheduler.step()

            if step % 20 == 0 and local_rank == 0:
                print(datetime.now(), 'loss:%.4f' % loss.data,
                      opt.param_groups[0]['lr'], step)
            step += 1

            # Manual LR decay, superseded by WarmupMultiStepLR:
            # if step == int(60000 * x) or step == int(80000 * x):
            #     for param_group in opt.param_groups:
            #         param_group['lr'] = param_group['lr'] / 10
            #         print('***************************', param_group['lr'], local_rank)

            # Save the unwrapped weights (model.module) from rank 0 only.
            if ((step <= 10000 and step % 1000 == 0) or step % 5000 == 0 or step == 1) \
                    and local_rank == 0:
                torch.save(model.module.state_dict(),
                           './models/Mask_Rcnn_101_%dx_%d_1_%d.pth' % (x, step, local_rank))
            if step >= 90010 * x:
                flag = True
                break
        if flag:
            break
    if local_rank == 0:
        torch.save(model.module.state_dict(),
                   './models/Mask_Rcnn_101_%dx_final_1_%d.pth' % (x, local_rank))
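
# Launch sketch (assumptions: this file is saved as train.py, all GPUs sit on one
# node, and the legacy torch.distributed.launch entry point is used, which is what
# supplies the --local_rank flag parsed in train_dist):
#
#   python -m torch.distributed.launch --nproc_per_node=<config.gpus> train.py
#
# Each spawned process then drives train_dist() on its own GPU over the NCCL
# backend; only rank 0 prints logs and writes checkpoints.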