def model_compile_para(self):
    compile_para = dict()
    compile_para["optimizer"] = tf.keras.optimizers.Adam(
        learning_rate=self.learning_rate)
    compile_para["loss"] = {
        "regression": SmoothL1Loss(),
        "classification": FocalLoss()
    }
    return compile_para
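# How the dict above is consumed is not shown here; a minimal sketch (an
# assumption, not the original code) for a two-head tf.keras model whose
# output layers are named "regression" and "classification", so Keras can
# match the loss dict to outputs by key:
def compile_model(self):
    # `self.model` is hypothetical; tf.keras.Model.compile accepts an
    # optimizer instance and a per-output loss dict.
    compile_para = self.model_compile_para()
    self.model.compile(optimizer=compile_para["optimizer"],
                       loss=compile_para["loss"])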
def compute_loss(outputs, labels, loss_method='binary'):
    loss = 0.
    if loss_method == 'binary':
        labels = labels.unsqueeze(1)
        loss = F.binary_cross_entropy(torch.sigmoid(outputs), labels)
    elif loss_method == 'cross_entropy':
        loss = F.cross_entropy(outputs, labels)
    elif loss_method == 'focal_loss':
        loss = FocalLoss()(outputs, labels)
    elif loss_method == 'ghmc':
        loss = GHMC()(outputs, labels)
    return loss
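# FocalLoss is called throughout these snippets but never defined in them.
# A minimal multi-class sketch (an assumption, not any of the original
# implementations) of FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t):
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        # alpha may be a per-class weight list, e.g. alpha=[4, 6, 1] below.
        self.alpha = None if alpha is None else torch.tensor(
            alpha, dtype=torch.float)
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: (N, C); targets: (N,) holding class indices.
        log_probs = F.log_softmax(logits, dim=-1)
        weight = None if self.alpha is None else self.alpha.to(logits.device)
        ce = F.nll_loss(log_probs, targets, weight=weight, reduction='none')
        p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()
        loss = (1.0 - p_t)**self.gamma * ce  # down-weight easy examples
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss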
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        position_ids=position_ids,
                        head_mask=head_mask,
                        inputs_embeds=inputs_embeds,
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                        return_dict=return_dict)
    pooled_output = outputs[1]
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

    loss = None
    if labels is not None:
        if self.config.problem_type is None:
            self.config.problem_type = 'single_label_classification'
        elif self.config.problem_type != 'single_label_classification':
            raise NotImplementedError(self.__doc__)
        if self.config.problem_type == 'single_label_classification':
            # loss_fct = DiceLoss()
            loss_fct = FocalLoss(gamma=2, alpha=[4, 6, 1], reduction='sum')
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    if not return_dict:
        output = (logits, ) + outputs[2:]
        return ((loss, ) + output) if loss is not None else output

    return SequenceClassifierOutput(loss=loss,
                                    logits=logits,
                                    hidden_states=outputs.hidden_states,
                                    attentions=outputs.attentions)
def main():
    lr = 5e-4
    gamma = 0.2
    num_classes = 21
    epoch = 300
    batch_size = 1
    # data_path = '/mnt/storage/project/data/VOCdevkit/VOC2007'
    data_path = '~/datasets/VOC/VOCdevkit/VOC2007'

    # define data
    data_set = LoadVocDataSets(data_path, 'trainval', AnnotationTransform(),
                               PreProcess(resize=(600, 600)))
    # define model
    model = RetinaNet(num_classes)
    # define criterion
    criterion = FocalLoss(num_classes)
    # define optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)

    # set iteration numbers
    epoch_size = len(data_set) // batch_size
    max_iter = epoch_size * epoch

    train_loss = 0
    # start iteration
    for iteration in range(max_iter):
        if iteration % epoch_size == 0:
            # create batch iterator at each epoch boundary
            batch_iter = iter(
                DataLoader(data_set,
                           batch_size,
                           shuffle=True,
                           num_workers=6,
                           collate_fn=data_set.detection_collate))
        images, loc_targets, cls_targets = next(batch_iter)

        optimizer.zero_grad()
        loc_preds, cls_preds = model(images)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        print('train_loss: %.3f' % (loss.item()))
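# The RetinaNet loops here call the criterion as
# criterion(loc_preds, loc_targets, cls_preds, cls_targets). A sketch of
# such a detection loss (an assumption; argument order and target encoding
# differ between the snippets, e.g. a later loop swaps cls and loc and
# returns the two terms separately):
import torch
import torch.nn as nn
import torch.nn.functional as F


class DetectionFocalLossSketch(nn.Module):
    """Smooth L1 on positive anchors plus sigmoid focal loss on all anchors.

    Assumes cls_targets holds per-anchor class indices, with 0 for
    background and -1 for ignored anchors.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2.0):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, loc_preds, loc_targets, cls_preds, cls_targets):
        pos = cls_targets > 0  # anchors matched to an object
        num_pos = pos.sum().clamp(min=1).float()

        # box regression: smooth L1 over positive anchors only
        mask = pos.unsqueeze(2).expand_as(loc_preds)
        loc_loss = F.smooth_l1_loss(loc_preds[mask], loc_targets[mask],
                                    reduction='sum')

        # classification: sigmoid focal loss over non-ignored anchors
        valid = cls_targets > -1
        target = F.one_hot(cls_targets[valid].clamp(min=0),
                           self.num_classes + 1)[:, 1:].float()
        logits = cls_preds[valid]
        ce = F.binary_cross_entropy_with_logits(logits, target,
                                                reduction='none')
        p = torch.sigmoid(logits)
        p_t = p * target + (1 - p) * (1 - target)
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
        cls_loss = (alpha_t * (1 - p_t)**self.gamma * ce).sum()

        return (loc_loss + cls_loss) / num_pos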
def compute_loss(self, logits, labels):
    if self.loss_type == "ce":
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, 2), labels.view(-1))
    elif self.loss_type == "focal":
        loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="mean")
        loss = loss_fct(logits.view(-1, 2), labels.view(-1))
    elif self.loss_type == "dice":
        loss_fct = DiceLoss(with_logits=True,
                            smooth=self.args.dice_smooth,
                            ohem_ratio=self.args.dice_ohem,
                            alpha=self.args.dice_alpha,
                            square_denominator=self.args.dice_square,
                            reduction="mean")
        loss = loss_fct(logits.view(-1, self.num_classes), labels)
    else:
        raise ValueError("unknown loss_type: {}".format(self.loss_type))
    return loss
def __init__(self, args, ckpt):
    super(LossFunction, self).__init__()
    ckpt.write_log('[INFO] Making loss...')
    self.nGPU = args.nGPU
    self.args = args
    self.loss = []
    for loss in args.loss.split('+'):
        weight, loss_type = loss.split('*')
        if loss_type == 'CrossEntropy':
            if args.if_labelsmooth:
                loss_function = CrossEntropyLabelSmooth(
                    num_classes=args.num_classes)
                ckpt.write_log('[INFO] Label Smoothing On.')
            else:
                loss_function = nn.CrossEntropyLoss()
        elif loss_type == 'Triplet':
            loss_function = TripletLoss(args.margin)
        elif loss_type == 'GroupLoss':
            loss_function = GroupLoss(total_classes=args.num_classes,
                                      max_iter=args.T,
                                      num_anchors=args.num_anchors)
        elif loss_type == 'MSLoss':
            loss_function = MultiSimilarityLoss(margin=args.margin)
        elif loss_type == 'Focal':
            loss_function = FocalLoss(reduction='mean')
        elif loss_type == 'OSLoss':
            loss_function = OSM_CAA_Loss()
        elif loss_type == 'CenterLoss':
            loss_function = CenterLoss(num_classes=args.num_classes,
                                       feat_dim=args.feats)
        self.loss.append({
            'type': loss_type,
            'weight': float(weight),
            'function': loss_function
        })
    if len(self.loss) > 1:
        self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
    self.log = torch.Tensor()
def __init__(self, args, ckpt):
    super(Loss, self).__init__()
    print('[INFO] Making loss...')
    self.nGPU = args.nGPU
    self.args = args
    self.loss = []
    self.loss_module = nn.ModuleList()
    self.device = torch.device('cpu' if args.cpu else 'cuda')
    for loss in args.loss.split('+'):
        weight, loss_type = loss.split('*')
        if loss_type == 'CrossEntropy':
            loss_function = nn.CrossEntropyLoss()
        elif loss_type == 'Triplet':
            loss_function = TripletSemihardLoss(self.device, args.margin)
        elif loss_type == 'FocalLoss':
            loss_function = FocalLoss(args.num_classes)
        # self.loss is a list of dicts, one per loss term: its type, its
        # weight, and the loss function itself.
        self.loss.append({
            'type': loss_type,
            'weight': float(weight),
            'function': loss_function
        })
    if len(self.loss) > 1:  # multiple losses are combined
        self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

    for l in self.loss:
        if l['function'] is not None:
            print('{:.3f} * {}'.format(l['weight'], l['type']))
            self.loss_module.append(l['function'])

    self.log = torch.Tensor()
    self.loss_module.to(self.device)
    if args.load != '':
        self.load(ckpt.dir, cpu=args.cpu)
    if not args.cpu and args.nGPU > 1:  # multi-GPU
        self.loss_module = nn.DataParallel(self.loss_module,
                                           range(args.nGPU))
def get_loss(config):
    """Return the loss function selected by the config."""
    loss = None
    if config['loss_config'] == 'multibox':
        loss = MultiBoxLoss(class_count=config['class_count'],
                            threshold=config['threshold'],
                            pos_neg_ratio=config['pos_neg_ratio'],
                            use_gpu=config['use_gpu'])
    elif config['loss_config'] == 'focal':
        loss = FocalLoss(class_count=config['class_count'],
                         threshold=config['threshold'],
                         alpha=config['focal_alpha'],
                         gamma=config['focal_gamma'],
                         use_gpu=config['use_gpu'])
    return loss
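# Hypothetical usage of get_loss above. The keys are the ones referenced in
# the function body; the values here are illustrative only.
example_config = {
    'loss_config': 'focal',
    'class_count': 21,
    'threshold': 0.5,
    'focal_alpha': 0.25,
    'focal_gamma': 2.0,
    'use_gpu': False,
}
criterion = get_loss(example_config)  # -> a FocalLoss instance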
                              drop_last=True)
val_dataloader = DataLoader(val_data,
                            batch_size=1,
                            shuffle=True,
                            num_workers=opts.n_workers,
                            pin_memory=True,
                            drop_last=False)
print("Length of train dataloader = ", len(train_dataloader))
print("Length of validation dataloader = ", len(val_dataloader))

# define the model
print("Loading model... ", opts.model, opts.model_depth)
model, parameters = generate_model(opts)
criterion = FocalLoss(
    alpha=[0.8, 0.4],
    gamma=2,
    criterion=nn.CrossEntropyLoss(reduction='none').cuda()).cuda()

log_path = os.path.join(opts.result_path, opts.dataset)
if not os.path.exists(log_path):
    os.makedirs(log_path)

if opts.log == 1:
    if opts.resume_path1:
        begin_epoch = int(opts.resume_path1.split('/')[-1].split('_')[1])
    epoch_logger = Logger(
        os.path.join(
            log_path,
            '{}_train_clip{}model{}{}.log'.format(opts.dataset,
                                                  opts.sample_duration,
                                                  opts.model,
                                                  opts.model_depth)),
def main(args, logger):
    writer = SummaryWriter(
        log_dir=os.path.join('logs', args.dataset, args.model_name,
                             args.loss))
    train_loader, test_loader = load_data(args)

    if args.dataset == 'CIFAR10':
        num_classes = 10
    elif args.dataset == 'CIFAR100':
        num_classes = 100
    elif args.dataset == 'TINY_IMAGENET':
        num_classes = 200
    elif args.dataset == 'IMAGENET':
        num_classes = 1000
    print('Model name :: {}, Dataset :: {}, Num classes :: {}'.format(
        args.model_name, args.dataset, num_classes))

    if args.model_name == 'mixnet_s':
        model = mixnet_s(num_classes=num_classes, dataset=args.dataset)
        # model = mixnet_s(num_classes=num_classes)
    elif args.model_name == 'mixnet_m':
        model = mixnet_m(num_classes=num_classes, dataset=args.dataset)
    elif args.model_name == 'mixnet_l':
        model = mixnet_l(num_classes=num_classes, dataset=args.dataset)
    elif args.model_name == 'ghostnet':
        model = ghostnet(num_classes=num_classes)
    elif args.model_name == 'ghostmishnet':
        model = ghostmishnet(num_classes=num_classes)
    elif args.model_name == 'ghosthmishnet':
        model = ghosthmishnet(num_classes=num_classes)
    elif args.model_name == 'ghostsharkfinnet':
        model = ghostsharkfinnet(num_classes=num_classes)
    elif args.model_name == 'mobilenetv2':
        model = models.mobilenet_v2(num_classes=num_classes)
    elif args.model_name == 'mobilenetv3_s':
        model = mobilenetv3_small(num_classes=num_classes)
    elif args.model_name == 'mobilenetv3_l':
        model = mobilenetv3_large(num_classes=num_classes)
    else:
        raise NotImplementedError

    if args.pretrained_model:
        filename = 'best_model_' + str(args.dataset) + '_' + str(
            args.model_name) + '_ckpt.tar'
        print('filename :: ', filename)
        file_path = os.path.join('./checkpoint', filename)
        checkpoint = torch.load(file_path)

        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        best_acc5 = checkpoint['best_acc5']
        model_parameters = checkpoint['parameters']
        print(
            'Load model, Parameters: {0}, Start_epoch: {1}, Acc1: {2}, Acc5: {3}'
            .format(model_parameters, start_epoch, best_acc1, best_acc5))
        logger.info(
            'Load model, Parameters: {0}, Start_epoch: {1}, Acc1: {2}, Acc5: {3}'
            .format(model_parameters, start_epoch, best_acc1, best_acc5))
    else:
        start_epoch = 1
        best_acc1 = 0.0
        best_acc5 = 0.0

    if args.cuda:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.cuda()

    print("Number of model parameters: ", get_model_parameters(model))
    logger.info("Number of model parameters: {0}".format(
        get_model_parameters(model)))

    if args.loss == 'ce':
        criterion = nn.CrossEntropyLoss()
    elif args.loss == 'focal':
        criterion = FocalLoss()
    else:
        raise NotImplementedError

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    #     optimizer, T_0=10, T_mult=2, eta_min=0.001)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[30, 60], gamma=0.1)  # learning-rate decay

    for epoch in range(start_epoch, args.epochs + 1):
        # adjust_learning_rate(optimizer, epoch, args)
        train(model, train_loader, optimizer, criterion, epoch, args, logger,
              writer)
        acc1, acc5 = eval(model, test_loader, criterion, args)
        lr_scheduler.step()

        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        if is_best:
            best_acc5 = acc5

        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        filename = 'model_' + str(args.dataset) + '_' + str(
            args.model_name) + '_ckpt.tar'
        print('filename :: ', filename)

        parameters = get_model_parameters(model)

        if torch.cuda.device_count() > 1:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.model_name,
                    'state_dict': model.module.state_dict(),
                    'best_acc1': best_acc1,
                    'best_acc5': best_acc5,
                    'optimizer': optimizer.state_dict(),
                    'parameters': parameters,
                }, is_best, filename)
        else:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.model_name,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'best_acc5': best_acc5,
                    'optimizer': optimizer.state_dict(),
                    'parameters': parameters,
                }, is_best, filename)

        writer.add_scalar('Test/Acc1', acc1, epoch)
        writer.add_scalar('Test/Acc5', acc5, epoch)
        print(" Test best acc1:", best_acc1, " acc1: ", acc1, " acc5: ", acc5)

    writer.close()
def train():
    args = parse_args()
    if args.use_tfboard:
        writer = SummaryWriter()

    # data loader
    print('load data')
    cfg.DATASET_NAME = args.dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    dataset = VOCDataset(transform=transform, train=True)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            collate_fn=dataset.collate_fn)
    iter_per_epoch = int(len(dataset) / args.batch_size)

    # load model
    print('load model')
    model = RetinaNet()
    model.load_state_dict(torch.load('./pretrained_model/model.pth'))
    model.freeze_bn()
    if args.use_GPU:
        model = model.cuda()
    if args.mGPUs:
        model = torch.nn.DataParallel(
            model, device_ids=range(torch.cuda.device_count()))
    model.train()

    # criterion
    criterion = FocalLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    print('start training')
    for epoch in range(args.epochs):
        train_data_iter = iter(dataloader)
        train_loss = 0
        fg, tp = 0, 0
        for step in range(iter_per_epoch):
            im_data, cls_targets, loc_targets, im_sizes = next(
                train_data_iter)
            if args.use_GPU:
                im_data = im_data.cuda()
                cls_targets = cls_targets.cuda()
                loc_targets = loc_targets.cuda()
            im_data = Variable(im_data)
            cls_targets = Variable(cls_targets)
            loc_targets = Variable(loc_targets)

            cls_preds, loc_preds = model(im_data)
            cls_loss, loc_loss = criterion(cls_preds, cls_targets, loc_preds,
                                           loc_targets)
            loss = cls_loss + loc_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # calculate classification accuracy
            cls_t = cls_targets.clone()
            cls_t = cls_t.view(-1, cfg.CLASS_NUM)
            cls_max, cls_argmax = torch.max(cls_t, dim=-1)
            fg_inds = torch.eq(cls_max, 1.)
            cls_p = cls_preds.clone()
            cls_p = cls_p.view(-1, cfg.CLASS_NUM)
            pred_info = torch.argmax(cls_p, dim=-1)
            tp += torch.sum(pred_info[fg_inds] == cls_argmax[fg_inds])
            fg += fg_inds.sum()

            if (step + 1) % args.display_interval == 0:
                train_loss /= args.display_interval
                print(
                    '[%d epoch | %d step] cls_loss: %.3f | loc_loss: %.3f | avg_loss: %.3f | cls_acc: %.3f'
                    % (epoch, step, cls_loss.item(), loc_loss.item(),
                       train_loss, float(tp) / float(fg)))
                if args.use_tfboard:
                    n_iter = epoch * iter_per_epoch + step + 1
                    writer.add_scalar('losses/loss', train_loss, n_iter)
                    writer.add_scalar('losses/cls_loss', cls_loss.item(),
                                      n_iter)
                    writer.add_scalar('losses/loc_loss', loc_loss.item(),
                                      n_iter)
                    writer.add_scalar('acc/cls_acc',
                                      float(tp) / float(fg), n_iter)
                train_loss = 0
                fg, tp = 0, 0

        if not os.path.exists(args.output_dir):
            os.mkdir(args.output_dir)
        if (epoch + 1) % args.save_interval == 0:
            print('saving model')
            save_name = os.path.join(
                args.output_dir, 'retinanet_epoch_{}.pth'.format(epoch + 1))
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
            }, save_name)
def train(train_config_file):
    """
    Medical image segmentation training engine
    :param train_config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(train_config_file), 'Config not found: {}'.format(
        train_config_file)

    # load config file
    train_cfg = load_config(train_config_file)

    # clean the existing folder if training from scratch
    model_folder = os.path.join(train_cfg.general.save_dir,
                                train_cfg.general.model_scale)
    if os.path.isdir(model_folder):
        if train_cfg.general.resume_epoch < 0:
            shutil.rmtree(model_folder)
            os.makedirs(model_folder)
    else:
        os.makedirs(model_folder)

    # copy training and inference config files to the model folder
    shutil.copy(train_config_file,
                os.path.join(model_folder, 'train_config.py'))
    infer_config_file = os.path.join(os.path.dirname(__file__), 'config',
                                     'infer_config.py')
    shutil.copy(infer_config_file,
                os.path.join(train_cfg.general.save_dir, 'infer_config.py'))

    # enable logging
    log_file = os.path.join(model_folder, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(train_cfg.general.seed)
    torch.manual_seed(train_cfg.general.seed)
    if train_cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(train_cfg.general.seed)

    # dataset
    train_dataset = SegmentationDataset(
        mode='train',
        im_list=train_cfg.general.train_im_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=train_cfg.train.batchsize,
                                   num_workers=train_cfg.train.num_threads,
                                   pin_memory=True,
                                   shuffle=True)

    val_dataset = SegmentationDataset(
        mode='val',
        im_list=train_cfg.general.val_im_list,
        num_classes=train_cfg.dataset.num_classes,
        spacing=train_cfg.dataset.spacing,
        crop_size=train_cfg.dataset.crop_size,
        sampling_method=train_cfg.dataset.sampling_method,
        random_translation=train_cfg.dataset.random_translation,
        random_scale=train_cfg.dataset.random_scale,
        interpolation=train_cfg.dataset.interpolation,
        crop_normalizers=train_cfg.dataset.crop_normalizers)
    val_data_loader = DataLoader(val_dataset,
                                 batch_size=1,
                                 num_workers=1,
                                 shuffle=False)

    # define network
    net = GlobalLocalNetwork(train_dataset.num_modality(),
                             train_cfg.dataset.num_classes)
    net.apply(kaiming_weight_init)
    max_stride = net.max_stride()
    if train_cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(
            net, device_ids=list(range(train_cfg.general.num_gpus)))
        net = net.cuda()
    assert np.all(np.array(train_cfg.dataset.crop_size) % max_stride == 0), \
        'crop size not divisible by max stride'

    # training optimizer
    opt = optim.Adam(net.parameters(),
                     lr=train_cfg.train.lr,
                     betas=train_cfg.train.betas)

    # load checkpoint if resume epoch > 0
    if train_cfg.general.resume_epoch >= 0:
        last_save_epoch = load_checkpoint(train_cfg.general.resume_epoch, net,
                                          opt, model_folder)
    else:
        last_save_epoch = 0

    if train_cfg.loss.name == 'Focal':
        # reuse the focal loss if it exists
        loss_func = FocalLoss(class_num=train_cfg.dataset.num_classes,
                              alpha=train_cfg.loss.obj_weight,
                              gamma=train_cfg.loss.focal_gamma,
                              use_gpu=train_cfg.general.num_gpus > 0)
    else:
        raise ValueError('Unknown loss function')

    writer = SummaryWriter(os.path.join(model_folder, 'tensorboard'))

    max_avg_dice = 0
    for epoch_idx in range(1, train_cfg.train.epochs + 1):
        train_one_epoch(net, train_cfg.loss.branch_weight, opt,
                        train_data_loader,
                        train_cfg.dataset.down_sample_ratio, loss_func,
                        train_cfg.general.num_gpus,
                        epoch_idx + last_save_epoch, logger, writer,
                        train_cfg.train.print_freq,
                        train_cfg.debug.save_inputs,
                        os.path.join(model_folder, 'debug'))

        # evaluation
        if epoch_idx % train_cfg.train.save_epochs == 0:
            avg_dice = evaluate_one_epoch(
                net, val_data_loader, train_cfg.dataset.crop_size,
                train_cfg.dataset.down_sample_ratio,
                train_cfg.dataset.crop_normalizers[0], Metrics(),
                [idx for idx in range(1, train_cfg.dataset.num_classes)],
                train_cfg.loss.branch_type)
            if max_avg_dice < avg_dice:
                max_avg_dice = avg_dice
                save_checkpoint(net, opt, epoch_idx, train_cfg, max_stride, 1)
                msg = 'epoch: {}, best dice ratio: {}'
            else:
                msg = 'epoch: {}, dice ratio: {}'
            msg = msg.format(epoch_idx, avg_dice)
            logger.info(msg)
def run_train():
    assert torch.cuda.is_available(), 'Error: CUDA not found!'
    start_epoch = 0  # start from epoch 0 or the last checkpoint epoch
    best_loss = float('inf')  # best (lowest) test loss so far

    # Data
    print('Load ListDataset')
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    trainset = ListDataset(img_dir=config.img_dir,
                           list_filename=config.train_list_filename,
                           label_map_filename=config.label_map_filename,
                           train=True,
                           transform=transform,
                           input_size=config.img_res)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=trainset.collate_fn)
    testset = ListDataset(img_dir=config.img_dir,
                          list_filename=config.test_list_filename,
                          label_map_filename=config.label_map_filename,
                          train=False,
                          transform=transform,
                          input_size=config.img_res)
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=config.test_batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=testset.collate_fn)

    # Model
    net = RetinaNet()
    if os.path.exists(config.checkpoint_filename):
        print('Load saved checkpoint: {}'.format(config.checkpoint_filename))
        checkpoint = torch.load(config.checkpoint_filename)
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']
        start_epoch = checkpoint['epoch']
    else:
        print('Load pretrained model: {}'.format(config.pretrained_filename))
        if not os.path.exists(config.pretrained_filename):
            import_pretrained_resnet()
        net.load_state_dict(torch.load(config.pretrained_filename))
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))
    net.cuda()

    criterion = FocalLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=1e-3,
                          momentum=0.9,
                          weight_decay=1e-4)

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        net.train()
        net.module.freeze_bn()
        train_loss = 0
        total_batches = int(
            math.ceil(trainloader.dataset.num_samples /
                      trainloader.batch_size))
        for batch_idx, targets in enumerate(trainloader):
            inputs = targets[0]
            loc_targets = targets[1]
            cls_targets = targets[2]
            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()

            optimizer.zero_grad()
            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                  (epoch, batch_idx, total_batches, loss.item(),
                   train_loss / (batch_idx + 1)))

    # Test
    def test(epoch):
        nonlocal best_loss
        print('\nTest')
        net.eval()
        test_loss = 0
        total_batches = int(
            math.ceil(testloader.dataset.num_samples /
                      testloader.batch_size))
        for batch_idx, targets in enumerate(testloader):
            inputs = targets[0]
            loc_targets = targets[1]
            cls_targets = targets[2]
            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()

            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            test_loss += loss.item()
            print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                  (epoch, batch_idx, total_batches, loss.item(),
                   test_loss / (batch_idx + 1)))

        # Save checkpoint
        test_loss /= len(testloader)
        if test_loss < best_loss:
            print('Save checkpoint: {}'.format(config.checkpoint_filename))
            state = {
                'net': net.module.state_dict(),
                'loss': test_loss,
                'epoch': epoch,
            }
            if not os.path.exists(
                    os.path.dirname(config.checkpoint_filename)):
                os.makedirs(os.path.dirname(config.checkpoint_filename))
            torch.save(state, config.checkpoint_filename)
            best_loss = test_loss

    for epoch in range(start_epoch, start_epoch + 1000):
        train(epoch)
        test(epoch)
def compute_loss(self,
                 start_logits,
                 end_logits,
                 span_logits,
                 start_labels,
                 end_labels,
                 match_labels,
                 start_label_mask,
                 end_label_mask,
                 answerable_cls_logits=None,
                 answerable_cls_labels=None):
    batch_size, seq_len = start_logits.size()[0], start_logits.size()[1]

    start_float_label_mask = start_label_mask.view(-1).float()
    end_float_label_mask = end_label_mask.view(-1).float()
    match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(
        -1, -1, seq_len)
    match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(
        -1, seq_len, -1)
    match_label_mask = match_label_row_mask & match_label_col_mask
    # torch.triu returns the upper triangular part of a matrix (or batch of
    # matrices); the other elements of the result tensor are set to 0.
    # A named entity's start position must be smaller than or equal to its
    # end position.
    match_label_mask = torch.triu(match_label_mask, 0)  # start <= end

    if self.args.span_loss_candidates == "all":
        # naive mask
        float_match_label_mask = match_label_mask.view(batch_size,
                                                       -1).float()
    else:
        # use only predicted or golden start/end to compute the match loss
        logits_size = start_logits.shape[-1]
        if logits_size == 1:
            start_preds, end_preds = start_logits > 0, end_logits > 0
            start_preds, end_preds = torch.squeeze(
                start_preds, dim=-1), torch.squeeze(end_preds, dim=-1)
        elif logits_size == 2:
            start_preds, end_preds = torch.argmax(
                start_logits, dim=-1), torch.argmax(end_logits, dim=-1)
        else:
            raise ValueError("start/end logits last dim should be 1 or 2")

        if self.args.span_loss_candidates == "gold":
            match_candidates = (
                (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
        elif self.args.span_loss_candidates == "gold_random":
            gold_matrix = (
                (start_labels.unsqueeze(-1).expand(-1, -1, seq_len) > 0)
                & (end_labels.unsqueeze(-2).expand(-1, seq_len, -1) > 0))
            data_generator = torch.Generator()
            data_generator.manual_seed(self.args.seed)
            random_matrix = torch.empty(batch_size, seq_len,
                                        seq_len).uniform_(0, 1)
            random_matrix = torch.bernoulli(
                random_matrix, generator=data_generator).long()
            random_matrix = random_matrix.cuda()
            match_candidates = torch.logical_or(gold_matrix, random_matrix)
        elif self.args.span_loss_candidates == "gold_pred":
            match_candidates = torch.logical_or(
                (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                 & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                 & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)))
        elif self.args.span_loss_candidates == "gold_pred_random":
            gold_and_pred = torch.logical_or(
                (start_preds.unsqueeze(-1).expand(-1, -1, seq_len)
                 & end_preds.unsqueeze(-2).expand(-1, seq_len, -1)),
                (start_labels.unsqueeze(-1).expand(-1, -1, seq_len)
                 & end_labels.unsqueeze(-2).expand(-1, seq_len, -1)))
            data_generator = torch.Generator()
            data_generator.manual_seed(self.args.seed)
            random_matrix = torch.empty(batch_size, seq_len,
                                        seq_len).uniform_(0, 1)
            random_matrix = torch.bernoulli(
                random_matrix, generator=data_generator).long()
            random_matrix = random_matrix.cuda()
            match_candidates = torch.logical_or(gold_and_pred, random_matrix)
        else:
            raise ValueError("unknown span_loss_candidates: {}".format(
                self.args.span_loss_candidates))
        match_label_mask = match_label_mask & match_candidates
        float_match_label_mask = match_label_mask.view(batch_size,
                                                       -1).float()

    if self.loss_type == "bce":
        start_end_logits_size = start_logits.shape[-1]
        if start_end_logits_size == 1:
            loss_fct = BCEWithLogitsLoss(reduction="none")
            start_loss = loss_fct(start_logits.view(-1),
                                  start_labels.view(-1).float())
            start_loss = (start_loss * start_float_label_mask
                          ).sum() / start_float_label_mask.sum()
            end_loss = loss_fct(end_logits.view(-1),
                                end_labels.view(-1).float())
            end_loss = (end_loss * end_float_label_mask
                        ).sum() / end_float_label_mask.sum()
        elif start_end_logits_size == 2:
            loss_fct = CrossEntropyLoss(reduction='none')
            start_loss = loss_fct(start_logits.view(-1, 2),
                                  start_labels.view(-1))
            start_loss = (start_loss * start_float_label_mask
                          ).sum() / start_float_label_mask.sum()
            end_loss = loss_fct(end_logits.view(-1, 2), end_labels.view(-1))
            end_loss = (end_loss * end_float_label_mask
                        ).sum() / end_float_label_mask.sum()
        else:
            raise ValueError("start/end logits last dim should be 1 or 2")

        if span_logits is not None:
            loss_fct = BCEWithLogitsLoss(reduction="mean")
            select_span_logits = torch.masked_select(
                span_logits.view(-1),
                match_label_mask.view(-1).bool())
            select_span_labels = torch.masked_select(
                match_labels.view(-1),
                match_label_mask.view(-1).bool())
            match_loss = loss_fct(select_span_logits.view(-1, 1),
                                  select_span_labels.float().view(-1, 1))
        else:
            match_loss = None

        if answerable_cls_logits is not None:
            loss_fct = BCEWithLogitsLoss(reduction="mean")
            answerable_loss = loss_fct(
                answerable_cls_logits.view(-1, 1),
                answerable_cls_labels.float().view(-1, 1))
        else:
            answerable_loss = None
    elif self.loss_type in ["dice", "adaptive_dice"]:
        # compute span loss
        loss_fct = DiceLoss(with_logits=True,
                            smooth=self.args.dice_smooth,
                            ohem_ratio=self.args.dice_ohem,
                            alpha=self.args.dice_alpha,
                            square_denominator=self.args.dice_square,
                            reduction="mean",
                            index_label_position=False)
        start_end_logits_size = start_logits.shape[-1]
        start_loss = loss_fct(start_logits.view(-1, start_end_logits_size),
                              start_labels.view(-1, 1))
        end_loss = loss_fct(end_logits.view(-1, start_end_logits_size),
                            end_labels.view(-1, 1))

        if span_logits is not None:
            select_span_logits = torch.masked_select(
                span_logits.view(-1),
                match_label_mask.view(-1).bool())
            select_span_labels = torch.masked_select(
                match_labels.view(-1),
                match_label_mask.view(-1).bool())
            match_loss = loss_fct(select_span_logits.view(-1, 1),
                                  select_span_labels.view(-1, 1))
        else:
            match_loss = None

        if answerable_cls_logits is not None:
            answerable_loss = loss_fct(answerable_cls_logits.view(-1, 1),
                                       answerable_cls_labels.view(-1, 1))
        else:
            answerable_loss = None
    else:
        loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="none")
        start_loss = loss_fct(
            FocalLoss.convert_binary_pred_to_two_dimension(
                start_logits.view(-1)), start_labels.view(-1))
        start_loss = (start_loss * start_float_label_mask
                      ).sum() / start_float_label_mask.sum()
        end_loss = loss_fct(
            FocalLoss.convert_binary_pred_to_two_dimension(
                end_logits.view(-1)), end_labels.view(-1))
        end_loss = (end_loss *
                    end_float_label_mask).sum() / end_float_label_mask.sum()

        if answerable_cls_logits is not None:
            answerable_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    answerable_cls_logits.view(-1)),
                answerable_cls_labels.view(-1))
            answerable_loss = answerable_loss.mean()
        else:
            answerable_loss = None

        if span_logits is not None:
            match_loss = loss_fct(
                FocalLoss.convert_binary_pred_to_two_dimension(
                    span_logits.view(-1)), match_labels.view(-1))
            match_loss = match_loss * float_match_label_mask.view(-1)
            match_loss = match_loss.sum() / (float_match_label_mask.sum() +
                                             1e-10)
        else:
            match_loss = None

    if answerable_loss is not None:
        return start_loss, end_loss, match_loss, answerable_loss
    return start_loss, end_loss, match_loss
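# Tiny worked example of the start <= end masking above (illustrative):
# for seq_len = 3, torch.triu with diagonal=0 keeps only the candidate
# spans (i, j) with i <= j:
#
# >>> torch.triu(torch.ones(3, 3, dtype=torch.bool), 0)
# tensor([[ True,  True,  True],
#         [False,  True,  True],
#         [False, False,  True]])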
def train():
    args = parse_args(base_dir, model_dir, total_epochs, batch_size, lr)
    trainset = VOCDataset(base_dir=args.base_dir,
                          split="train",
                          transform=transforms.Compose([
                              RandomScaleCrop(550, 512),
                              RandomHorizontalFlip(),
                              Normalize(mean=(0.485, 0.456, 0.406),
                                        std=(0.229, 0.224, 0.225)),
                              ToTensor()
                          ]))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch,
                             shuffle=True,
                             num_workers=4)

    print("starting loading the net and model")
    # net = Res34Unet(3, 21)
    net = PAN(3, 21)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.cuda:
        args.gpus = [int(x) for x in args.gpus.split(",")]
        net = nn.DataParallel(net, device_ids=args.gpus)
    net.to(device)

    # loss and optimizer
    criterion = FocalLoss()
    # criterion = MultiLovaszLoss()
    optimizer = optim.SGD(net.parameters(), lr=float(args.lr), momentum=0.9)
    # scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1,
    #                               patience=5, min_lr=0.00001)
    scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

    start_epoch = 1
    # resume training
    if args.resume_from is not None:
        if not os.path.isfile(args.resume_from):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.resume_from))
        checkpoint = torch.load(args.resume_from)
        # args.start_epoch = checkpoint['epoch']
        if args.cuda:
            net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            net.load_state_dict(checkpoint['model_state_dict'])
        print("resuming training from {}, epoch: {}".format(
            args.resume_from, checkpoint['epoch']))
        start_epoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("finishing loading the net and model")

    print("start training")
    for epoch in range(start_epoch, args.epoch + start_epoch):
        scheduler.step()  # note: newer PyTorch expects this after optimizer.step()
        epoch_loss = 0.0
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs
            inputs, labels = data["image"].to(device), data["mask"].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_loss += loss.item()
            if i % 10 == 9:  # print every 10 mini-batches
                print("epoch %2d, [%5d / %5d], lr: %5g, loss: %.3f" %
                      (epoch, (i + 1) * args.batch, len(trainset),
                       scheduler.get_lr()[0], running_loss / 10))
                # print("epoch %2d, [%5d / %5d], loss: %.5f" %
                #       (epoch, (i + 1) * args.batch, len(trainset),
                #        running_loss / 10))
                running_loss = 0.0
        # scheduler.step(epoch_loss / math.ceil(len(trainset) / args.batch))

        # save model
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": net.module.state_dict()
                if args.cuda else net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                # "lr": scheduler.get_lr()[0]
            }, os.path.join(model_dir, "epoch_{}.pth".format(epoch)))
    print("Finished training")
def train(args, tasks_archive, model):
    torch.backends.cudnn.benchmark = True

    if args.resume_ckp != '':
        logger.info('==> loading checkpoint: {}'.format(args.resume_ckp))
        checkpoint = torch.load(args.resume_ckp)

    model = nn.parallel.DataParallel(model)
    logger.info(' + model num_params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    if config.use_gpu:
        model.cuda()  # required before creating the optimizer?
        # cudnn.benchmark = True
    print(model)  # especially useful for debugging the model structure
    # summary(model, input_size=tuple([config.num_modality] + config.patch_size))
    # (takes some time; comment out during debugging) prints each layer's
    # output shape.
    # for name, m in model.named_modules():
    #     logger.info('module name:{}'.format(name))
    #     print(m)

    # lr
    lr = config.base_lr

    if args.resume_ckp != '':
        optimizer = checkpoint['optimizer']
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=config.weight_decay)

    # loss
    dice_loss = MulticlassDiceLoss()
    ce_loss = nn.CrossEntropyLoss()
    focal_loss = FocalLoss(gamma=2)

    # prep data
    tasks = args.tasks  # list
    tb_loaders = list()  # train batch loaders
    len_loader = list()
    for task in tasks:
        tb_loader = tb_load(task)
        tb_loader.enQueue(tasks_archive[task]['fold' + str(args.fold)],
                          config.patch_size)
        tb_loaders.append(tb_loader)
        len_loader.append(len(tb_loader))
    min_len_loader = np.min(len_loader)

    # init train values
    if args.resume_ckp != '':
        trLoss_queue = checkpoint['trLoss_queue']
        last_trLoss_ma = checkpoint['last_trLoss_ma']
    else:
        # queue storing the moving average of the total loss over the last
        # N epochs
        trLoss_queue = deque(maxlen=config.trLoss_win)
        last_trLoss_ma = None  # the previous moving average
    trLoss_queue_list = [
        deque(maxlen=config.trLoss_win) for i in range(len(tasks))
    ]
    last_trLoss_ma_list = [None] * len(tasks)
    trLoss_ma_list = [None] * len(tasks)

    if args.resume_epoch > 0:
        start_epoch = args.resume_epoch + 1
        iterations = args.resume_epoch * config.step_per_epoch + 1
    else:
        start_epoch = 1
        iterations = 1
    logger.info('start epoch: {}'.format(start_epoch))

    ## run train
    for epoch in range(start_epoch, config.max_epoch + 1):
        logger.info(' ----- training epoch {} -----'.format(epoch))
        epoch_st_time = time.time()

        model.train()
        loss_epoch = 0.0
        loss_epoch_list = [0] * len(tasks)
        num_batch_processed = 0  # growing
        num_batch_processed_list = [0] * len(tasks)
        for step in tqdm(range(config.step_per_epoch),
                         desc='{}: epoch{}'.format(args.trainMode, epoch)):
            config.step = iterations
            config.task_idx = (iterations - 1) % len(tasks)
            config.task = tasks[config.task_idx]

            # tensorboard: show lr
            config.writer.add_scalar('data/lr', lr, iterations - 1)

            for idx in range(len(tasks)):
                tb_loaders[idx].check_process()
            (batchImg, batchLabel, batchWeight,
             batchAugs) = tb_loaders[config.task_idx].gen_batch(
                 config.batch_size, config.patch_size)

            # cast all inputs to the same torch tensor type
            batchImg = torch.from_numpy(batchImg).float()
            batchLabel = torch.from_numpy(batchLabel).float()
            batchWeight = torch.from_numpy(batchWeight).float()
            if config.use_gpu:
                batchImg = batchImg.cuda()
                batchLabel = batchLabel.cuda()
                batchWeight = batchWeight.cuda()

            optimizer.zero_grad()

            if config.trainMode in ["universal"]:
                output, share_map, para_map = model(batchImg)
            else:
                output = model(batchImg)

            # tensorboard visualization of training
            for i in range(len(tasks)):
                if iterations > 200 and iterations % 1000 == i:
                    tb_images([
                        batchImg[0, 0, ...], batchLabel[0, ...],
                        torch.argmax(output[0, ...], dim=0)
                    ], [False, True, True], ['image', 'GT', 'PS'],
                              iterations,
                              tag='Train_idx{}_{}_batch{}_{}'.format(
                                  config.task_idx, config.task, 0,
                                  '_'.join(batchAugs[0])))
                    tb_images([
                        batchImg[config.batch_size - 1, 0, ...],
                        batchLabel[config.batch_size - 1, ...],
                        torch.argmax(output[config.batch_size - 1, ...],
                                     dim=0)
                    ], [False, True, True], ['image', 'GT', 'PS'],
                              iterations,
                              tag='Train_idx{}_{}_batch{}_{}_step{}'.format(
                                  config.task_idx, config.task,
                                  config.batch_size - 1,
                                  '_'.join(batchAugs[config.batch_size - 1]),
                                  iterations - 1))
                    if config.trainMode == "universal":
                        logger.info(
                            'share_map shape:{}, para_map shape:{}'.format(
                                str(share_map.shape), str(para_map.shape)))
                        tb_images(
                            [para_map[0, :, 64, ...],
                             share_map[0, :, 64, ...]], [False, False],
                            ['last_para_map', 'last_share_map'],
                            iterations,
                            tag='Train_idx{}_{}_para_share_maps_channels'.
                            format(config.task_idx, config.task))

            logger.info(
                '----- {}, train epoch {} time elapsed:{} -----'.format(
                    config.task, epoch,
                    tinies.timer(epoch_st_time, time.time())))

            output_softmax = F.softmax(output, dim=1)
            loss = lovasz_softmax(output_softmax, batchLabel,
                                  ignore=10) + focal_loss(output, batchLabel)
            loss.backward()
            optimizer.step()

            config.writer.add_scalar('data/loss_step', loss.item(),
                                     iterations)
            config.writer.add_scalar(
                'data/loss_step_idx{}_{}'.format(config.task_idx,
                                                 config.task), loss.item(),
                iterations)
            loss_epoch += loss.item()
            num_batch_processed += 1
            loss_epoch_list[config.task_idx] += loss.item()
            num_batch_processed_list[config.task_idx] += 1

            iterations += 1

        if epoch % config.save_epoch == 0:
            ckp_path = os.path.join(
                config.log_dir,
                '{}_{}_epoch{}_{}.pth.tar'.format(args.trainMode,
                                                  '_'.join(args.tasks),
                                                  epoch, tinies.datestr()))
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'trLoss_queue': trLoss_queue,
                    'last_trLoss_ma': last_trLoss_ma
                }, ckp_path)

        loss_epoch /= num_batch_processed
        config.writer.add_scalar('data/loss_epoch', loss_epoch,
                                 iterations - 1)
        for idx in range(len(tasks)):
            task = tasks[idx]
            loss_epoch_list[idx] /= num_batch_processed_list[idx]
            config.writer.add_scalar(
                'data/loss_epoch_idx{}_{}'.format(idx, task),
                loss_epoch_list[idx], iterations - 1)

        ### lr decay
        trLoss_queue.append(loss_epoch)
        # simple moving average; an exponential moving average could be used
        # instead
        trLoss_ma = np.asarray(trLoss_queue).mean()
        config.writer.add_scalar('data/trLoss_ma', trLoss_ma, iterations - 1)
        for idx in range(len(tasks)):
            task = tasks[idx]
            trLoss_queue_list[idx].append(loss_epoch_list[idx])
            trLoss_ma_list[idx] = np.asarray(trLoss_queue_list[idx]).mean()
            config.writer.add_scalar(
                'data/trLoss_ma_idx{}_{}'.format(idx, task),
                trLoss_ma_list[idx], iterations - 1)

        #### online eval
        Eval_bool = False
        if epoch >= config.start_val_epoch and epoch % config.val_epoch == 0:
            Eval_bool = True
        elif lr < 1e-8:
            Eval_bool = True
            logger.info(
                'lr is reduced to {}. Will do the last evaluation for all samples!'
                .format(lr))
        if Eval_bool:
            eval(args, tasks_archive, model, epoch, iterations - 1)

        ## stop if lr is too low
        if lr < 1e-8:
            logger.info('lr is reduced to {}. Job Done!'.format(lr))
            break

        ###### lr decay based on the current task
        if len(trLoss_queue) == trLoss_queue.maxlen:
            if last_trLoss_ma and last_trLoss_ma - trLoss_ma < 1e-4:  # 5e-3
                lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            last_trLoss_ma = trLoss_ma

        ## save model when lr < 1e-8
        if lr < 1e-8:
            ckp_path = os.path.join(
                config.log_dir,
                '{}_{}_epoch{}_{}.pth.tar'.format(args.trainMode,
                                                  '_'.join(args.tasks),
                                                  epoch, tinies.datestr()))
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'trLoss_queue': trLoss_queue,
                    'last_trLoss_ma': last_trLoss_ma
                }, ckp_path)
def __init__(self, alpha=10, weight=None):
    super(CombinedLoss, self).__init__()
    self.alpha = alpha
    self.dice_loss = MultiDiceLoss(weight)
    self.focal_loss = FocalLoss(weight)
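# CombinedLoss.forward is not shown in this snippet; a plausible sketch (an
# assumption, not the original code) in which `alpha` weights the focal term
# against the Dice term:
def forward(self, output, target):
    """Hypothetical composition: Dice loss plus alpha-weighted focal loss."""
    dice = self.dice_loss(output, target)
    focal = self.focal_loss(output, target)
    return dice + self.alpha * focal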