def train():
    """Train YOLACT with SGD, checkpointing and validating along the way.

    All settings are read from the module-level ``args`` (command line) and
    ``cfg`` (model/dataset configuration) objects. The function builds the
    dataset(s), the network, the optimizer and the loss, then runs the
    training loop until ``cfg.max_iter`` iterations have been performed.

    Weights are written under ``args.save_folder`` every
    ``args.save_interval`` iterations, on Ctrl+C (only when
    ``args.interrupt`` is set) and once more after the loop finishes.
    """
    # 1: Create the folder that receives the training results.
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    # 2: Prepare the training dataset through the COCO-style dataset wrapper.
    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))

    # The evaluation dataset is only needed when periodic validation is on.
    if args.validation_epoch > 0:
        setup_eval()
        val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
                                    info_file=cfg.dataset.valid_info,
                                    transform=BaseTransform(MEANS))

    # 3: Build the model and put it in train mode.
    # `net` and `yolact_net` are two names for the SAME object at this point.
    # `net` is later rebound to the DataParallel/loss wrapper, so `yolact_net`
    # is kept as a handle on the bare model for loading/saving weights.
    yolact_net = Yolact()
    net = yolact_net
    net.train()

    # 4: Logging and resume support.
    if args.log:
        log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()),
                  overwrite=(args.resume is None), log_gpu_stats=args.log_gpu)

    # I don't use the timer during training (I use a different timing method).
    # Apparently there's a race condition with multiple GPUs, so disable it just to be safe.
    timer.disable_all()

    # Both of these can set args.resume to None, so do them before the check
    if args.resume == 'interrupt':
        args.resume = SavePath.get_interrupt(args.save_folder)
    elif args.resume == 'latest':
        args.resume = SavePath.get_latest(args.save_folder, cfg.name)

    if args.resume is not None:
        print('Resuming training, loading {}...'.format(args.resume))
        yolact_net.load_weights(args.resume)

        if args.start_iter == -1:
            args.start_iter = SavePath.from_str(args.resume).iteration
    else:
        print('Initializing weights...')
        yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)

    # 5: Optimizer and loss function.
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.decay)
    criterion = MultiBoxLoss(num_classes=cfg.num_classes,
                             pos_threshold=cfg.positive_iou_threshold,
                             neg_threshold=cfg.negative_iou_threshold,
                             negpos_ratio=cfg.ohem_negpos_ratio)

    # 6: Optional explicit per-GPU batch split ("a,b,..."). A split that does
    # not add up to the batch size is a configuration error, so bail out.
    if args.batch_alloc is not None:
        args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')]
        if sum(args.batch_alloc) != args.batch_size:
            print('Error: Batch allocation (%s) does not sum to batch size (%s).'
                  % (args.batch_alloc, args.batch_size))
            exit(-1)

    # 7: Wrap model + loss so that one call to `net` does the forward pass and
    # the loss computation together (see CustomDataParallel and NetLoss).
    # `yolact_net` still refers to the bare model inside the wrapper.
    net = CustomDataParallel(NetLoss(net, criterion))
    if args.cuda:
        net = net.cuda()

    # 8: Run one dummy forward pass through the bare model to initialize
    # everything. Batch norm is frozen around it so the all-zero input does
    # not kill the stored running means, then re-enabled afterwards.
    if not cfg.freeze_bn:
        yolact_net.freeze_bn()  # Freeze bn so we don't kill our means
    dummy_input = torch.zeros(1, 3, cfg.max_size, cfg.max_size)
    if args.cuda:
        dummy_input = dummy_input.cuda()
    yolact_net(dummy_input)
    if not cfg.freeze_bn:
        yolact_net.freeze_bn(True)

    # 9: Loop bookkeeping. start_iter == -1 means "not set", so clamp to 0.
    iteration = max(args.start_iter, 0)
    last_time = time.time()

    epoch_size = len(dataset) // args.batch_size
    num_epochs = math.ceil(cfg.max_iter / epoch_size)

    # 10: Which learning rate adjustment step are we on? lr' = lr * gamma ^ step_index
    step_index = 0

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)

    # 11: Checkpoint path builder and moving averages used for reporting.
    def save_path(epoch, iteration):
        return SavePath(cfg.name, epoch, iteration).get_path(root=args.save_folder)

    time_avg = MovingAverage()

    global loss_types  # Forms the print order
    loss_avgs = {k: MovingAverage(100) for k in loss_types}

    # 12: Main training loop.
    print('Begin training!')
    print()
    # try-except so you can use ctrl+c to save early and stop training;
    # restart later from that point with --resume.
    try:
        for epoch in range(num_epochs):
            # Resume from start_iter: skip epochs that were already completed.
            if (epoch + 1) * epoch_size < iteration:
                continue

            for datum in data_loader:
                # Stop if we've reached an epoch if we're resuming from start_iter
                if iteration == (epoch + 1) * epoch_size:
                    break

                # Stop at the configured number of iterations even if mid-epoch
                if iteration == cfg.max_iter:
                    break

                # Change a config setting if we've reached the specified iteration
                changed = False
                for change in cfg.delayed_settings:
                    if iteration >= change[0]:
                        changed = True
                        cfg.replace(change[1])

                        # Reset the loss averages because things might have changed.
                        # loss_avgs is a dict, so iterate its values (the averages).
                        for avg in loss_avgs.values():
                            avg.reset()

                # If a config setting was changed, remove it from the list so we don't keep checking
                if changed:
                    cfg.delayed_settings = [x for x in cfg.delayed_settings if x[0] > iteration]

                # Warm up by linearly interpolating the learning rate from some smaller value
                if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                    set_lr(optimizer,
                           (args.lr - cfg.lr_warmup_init) * (iteration / cfg.lr_warmup_until)
                           + cfg.lr_warmup_init)

                # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
                while step_index < len(cfg.lr_steps) and iteration >= cfg.lr_steps[step_index]:
                    step_index += 1
                    set_lr(optimizer, args.lr * (args.gamma ** step_index))

                # Zero the grad to get ready to compute gradients
                optimizer.zero_grad()

                # Forward Pass + Compute loss at the same time (see CustomDataParallel and NetLoss)
                losses = net(datum)

                losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                loss = sum([losses[k] for k in losses])

                # Backprop
                loss.backward()  # Do this to free up vram even if loss is not finite
                if torch.isfinite(loss).item():
                    optimizer.step()

                # Add the loss to the moving average for bookkeeping
                for k in losses:
                    loss_avgs[k].add(losses[k].item())

                cur_time = time.time()
                elapsed = cur_time - last_time
                last_time = cur_time

                # Exclude graph setup from the timing information
                if iteration != args.start_iter:
                    time_avg.add(elapsed)

                if iteration % 10 == 0:
                    eta_str = str(datetime.timedelta(
                        seconds=(cfg.max_iter - iteration) * time_avg.get_avg())).split('.')[0]

                    total = sum([loss_avgs[k].get_avg() for k in losses])
                    loss_labels = sum([[k, loss_avgs[k].get_avg()]
                                       for k in loss_types if k in losses], [])

                    print(('[%3d] %7d ||' + (' %s: %.3f |' * len(losses)) + ' T: %.3f || ETA: %s || timer: %.3f')
                          % tuple([epoch, iteration] + loss_labels + [total, eta_str, elapsed]), flush=True)

                # Write the per-iteration record to the log file.
                if args.log:
                    precision = 5
                    loss_info = {k: round(losses[k].item(), precision) for k in losses}
                    loss_info['T'] = round(loss.item(), precision)

                    if args.log_gpu:
                        log.log_gpu_stats = (iteration % 10 == 0)  # nvidia-smi is sloooow

                    # cur_lr is a module-level value, not set in this function.
                    log.log('train', loss=loss_info, epoch=epoch, iter=iteration,
                            lr=round(cur_lr, 10), elapsed=elapsed)

                    log.log_gpu_stats = args.log_gpu

                iteration += 1

                # Periodic checkpoint.
                if iteration % args.save_interval == 0 and iteration != args.start_iter:
                    if args.keep_latest:
                        latest = SavePath.get_latest(args.save_folder, cfg.name)

                    print('Saving state, iter:', iteration)
                    yolact_net.save_weights(save_path(epoch, iteration))

                    if args.keep_latest and latest is not None:
                        if args.keep_latest_interval <= 0 or iteration % args.keep_latest_interval != args.save_interval:
                            print('Deleting old save...')
                            os.remove(latest)

            # This is done per epoch
            if args.validation_epoch > 0:
                if epoch % args.validation_epoch == 0 and epoch > 0:
                    compute_validation_map(epoch, iteration, yolact_net, val_dataset,
                                           log if args.log else None)

        # Compute validation mAP after training is finished. val_dataset only
        # exists when validation is enabled, so skip it otherwise instead of
        # crashing before the final weights are saved.
        if args.validation_epoch > 0:
            compute_validation_map(epoch, iteration, yolact_net, val_dataset,
                                   log if args.log else None)
    # 13: Ctrl+C stops training; optionally save an "_interrupt" checkpoint first.
    except KeyboardInterrupt:
        if args.interrupt:
            print('Stopping early. Saving network...')

            # Delete previous copy of the interrupted network so we don't spam the weights folder
            SavePath.remove_interrupt(args.save_folder)

            yolact_net.save_weights(save_path(epoch, repr(iteration) + '_interrupt'))
        exit()

    yolact_net.save_weights(save_path(epoch, iteration))
def train():
    """Train YOLACT with SGD, checkpointing and validating along the way.

    All settings come from the module-level ``args`` and ``cfg`` objects.
    Builds the dataset(s), model, optimizer and loss, then iterates over the
    training data until ``cfg.max_iter`` iterations have been performed.

    Weights are saved under ``args.save_folder`` every
    ``args.save_interval`` iterations, on Ctrl+C (only when
    ``args.interrupt`` is set) and once more after the loop finishes.
    """
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))

    # The validation dataset is only built when periodic validation is enabled.
    if args.validation_epoch > 0:
        setup_eval()
        val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
                                    info_file=cfg.dataset.valid_info,
                                    transform=BaseTransform(MEANS))

    # Parallel wraps the underlying module, but when saving and loading we don't want that
    yolact_net = Yolact()
    net = yolact_net
    net.train()

    if args.log:
        log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()),
                  overwrite=(args.resume is None), log_gpu_stats=args.log_gpu)

    # I don't use the timer during training (I use a different timing method).
    # Apparently there's a race condition with multiple GPUs, so disable it just to be safe.
    timer.disable_all()

    # Both of these can set args.resume to None, so do them before the check
    if args.resume == 'interrupt':
        args.resume = SavePath.get_interrupt(args.save_folder)
    elif args.resume == 'latest':
        args.resume = SavePath.get_latest(args.save_folder, cfg.name)

    if args.resume is not None:
        print('Resuming training, loading {}...'.format(args.resume))
        yolact_net.load_weights(args.resume)

        if args.start_iter == -1:
            args.start_iter = SavePath.from_str(args.resume).iteration
    else:
        print('Initializing weights...')
        yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.decay)
    criterion = MultiBoxLoss(num_classes=cfg.num_classes,
                             pos_threshold=cfg.positive_iou_threshold,
                             neg_threshold=cfg.negative_iou_threshold,
                             negpos_ratio=cfg.ohem_negpos_ratio)

    # Optional explicit per-GPU batch split; it must add up to the batch size.
    if args.batch_alloc is not None:
        args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')]
        if sum(args.batch_alloc) != args.batch_size:
            print('Error: Batch allocation (%s) does not sum to batch size (%s).'
                  % (args.batch_alloc, args.batch_size))
            exit(-1)

    # From here on `net` computes the losses; `yolact_net` is the bare model.
    net = CustomDataParallel(NetLoss(net, criterion))
    if args.cuda:
        net = net.cuda()

    # Initialize everything with one dummy forward pass.
    if not cfg.freeze_bn:
        yolact_net.freeze_bn()  # Freeze bn so we don't kill our means
    dummy_input = torch.zeros(1, 3, cfg.max_size, cfg.max_size)
    if args.cuda:
        dummy_input = dummy_input.cuda()
    yolact_net(dummy_input)
    if not cfg.freeze_bn:
        yolact_net.freeze_bn(True)

    # start_iter == -1 means "not set", so clamp to 0.
    iteration = max(args.start_iter, 0)
    last_time = time.time()

    epoch_size = len(dataset) // args.batch_size
    num_epochs = math.ceil(cfg.max_iter / epoch_size)

    # Which learning rate adjustment step are we on? lr' = lr * gamma ^ step_index
    step_index = 0

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)

    def save_path(epoch, iteration):
        return SavePath(cfg.name, epoch, iteration).get_path(root=args.save_folder)

    time_avg = MovingAverage()

    global loss_types  # Forms the print order
    loss_avgs = {k: MovingAverage(100) for k in loss_types}

    print('Begin training!')
    print()
    # try-except so you can use ctrl+c to save early and stop training
    try:
        for epoch in range(num_epochs):
            # Resume from start_iter
            if (epoch + 1) * epoch_size < iteration:
                continue

            for datum in data_loader:
                # Stop if we've reached an epoch if we're resuming from start_iter
                if iteration == (epoch + 1) * epoch_size:
                    break

                # Stop at the configured number of iterations even if mid-epoch
                if iteration == cfg.max_iter:
                    break

                # Change a config setting if we've reached the specified iteration
                changed = False
                for change in cfg.delayed_settings:
                    if iteration >= change[0]:
                        changed = True
                        cfg.replace(change[1])

                        # Reset the loss averages because things might have changed.
                        # loss_avgs is a dict, so iterate its values (the averages).
                        for avg in loss_avgs.values():
                            avg.reset()

                # If a config setting was changed, remove it from the list so we don't keep checking
                if changed:
                    cfg.delayed_settings = [x for x in cfg.delayed_settings if x[0] > iteration]

                # Warm up by linearly interpolating the learning rate from some smaller value
                if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                    set_lr(optimizer,
                           (args.lr - cfg.lr_warmup_init) * (iteration / cfg.lr_warmup_until)
                           + cfg.lr_warmup_init)

                # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
                while step_index < len(cfg.lr_steps) and iteration >= cfg.lr_steps[step_index]:
                    step_index += 1
                    set_lr(optimizer, args.lr * (args.gamma ** step_index))

                # Zero the grad to get ready to compute gradients
                optimizer.zero_grad()

                # Forward Pass + Compute loss at the same time (see CustomDataParallel and NetLoss)
                losses = net(datum)

                losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                loss = sum([losses[k] for k in losses])

                # Backprop
                loss.backward()  # Do this to free up vram even if loss is not finite
                if torch.isfinite(loss).item():
                    optimizer.step()

                # Add the loss to the moving average for bookkeeping
                for k in losses:
                    loss_avgs[k].add(losses[k].item())

                cur_time = time.time()
                elapsed = cur_time - last_time
                last_time = cur_time

                # Exclude graph setup from the timing information
                if iteration != args.start_iter:
                    time_avg.add(elapsed)

                if iteration % 10 == 0:
                    eta_str = str(datetime.timedelta(
                        seconds=(cfg.max_iter - iteration) * time_avg.get_avg())).split('.')[0]

                    total = sum([loss_avgs[k].get_avg() for k in losses])
                    loss_labels = sum([[k, loss_avgs[k].get_avg()]
                                       for k in loss_types if k in losses], [])

                    print(('[%3d] %7d ||' + (' %s: %.3f |' * len(losses)) + ' T: %.3f || ETA: %s || timer: %.3f')
                          % tuple([epoch, iteration] + loss_labels + [total, eta_str, elapsed]), flush=True)

                if args.log:
                    precision = 5
                    loss_info = {k: round(losses[k].item(), precision) for k in losses}
                    # 'T' is the total loss of this iteration.
                    loss_info['T'] = round(loss.item(), precision)

                    if args.log_gpu:
                        log.log_gpu_stats = (iteration % 10 == 0)  # nvidia-smi is sloooow

                    # cur_lr is a module-level value, not set in this function.
                    log.log('train', loss=loss_info, epoch=epoch, iter=iteration,
                            lr=round(cur_lr, 10), elapsed=elapsed)

                    log.log_gpu_stats = args.log_gpu

                iteration += 1

                if iteration % args.save_interval == 0 and iteration != args.start_iter:
                    if args.keep_latest:
                        latest = SavePath.get_latest(args.save_folder, cfg.name)

                    print('Saving state, iter:', iteration)
                    yolact_net.save_weights(save_path(epoch, iteration))

                    if args.keep_latest and latest is not None:
                        if args.keep_latest_interval <= 0 or iteration % args.keep_latest_interval != args.save_interval:
                            print('Deleting old save...')
                            os.remove(latest)

            # This is done per epoch
            if args.validation_epoch > 0:
                if epoch % args.validation_epoch == 0 and epoch > 0:
                    compute_validation_map(epoch, iteration, yolact_net, val_dataset,
                                           log if args.log else None)

        # Compute validation mAP after training is finished. val_dataset only
        # exists when validation is enabled, so skip it otherwise instead of
        # crashing before the final weights are saved.
        if args.validation_epoch > 0:
            compute_validation_map(epoch, iteration, yolact_net, val_dataset,
                                   log if args.log else None)
    except KeyboardInterrupt:
        if args.interrupt:
            print('Stopping early. Saving network...')

            # Delete previous copy of the interrupted network so we don't spam the weights folder
            SavePath.remove_interrupt(args.save_folder)

            yolact_net.save_weights(save_path(epoch, repr(iteration) + '_interrupt'))
        exit()

    yolact_net.save_weights(save_path(epoch, iteration))
def _gan_forward(net, dis_net, datum, detach):
    """Run YOLACT and then the discriminator on one batch.

    Returns ``(losses, output_pred, output_grou)``: the YOLACT loss dict, the
    discriminator output for the predicted segmentation and the discriminator
    output for the ground-truth masks.

    With ``detach=True`` the discriminator inputs are detached so that no
    gradient flows back into the generator (used for the discriminator step).
    """
    # seg_list: predicted masks, pred_list: predicted labels (per image).
    losses, seg_list, pred_list = net(datum)
    seg_clas, mask_clas, _, seg_size = seg_mask_clas(seg_list, pred_list, datum)

    # Downsample the input images to the segmentation resolution.
    image_list = [img.to(cuda0) for img in datum[0]]
    image = interpolate(torch.stack(image_list), size=seg_size,
                        mode='bilinear', align_corners=False)

    if detach:
        image = image.detach()
        seg_clas = seg_clas.detach()
        mask_clas = mask_clas.detach()

    output_pred = dis_net(img=image, seg=seg_clas)
    output_grou = dis_net(img=image, seg=mask_clas)
    return losses, output_pred, output_grou


def train():
    """Train YOLACT, optionally adversarially against a discriminator.

    All settings come from the module-level ``args`` and ``cfg`` objects.
    When ``cfg.pred_seg`` is set, every batch runs ``cfg.dis_iter`` rounds of
    a discriminator update followed by a generator (YOLACT) update, with
    mixed precision when ``cfg.amp`` is set. Otherwise the plain YOLACT
    training step is used.

    Weights of the YOLACT model are saved under ``args.save_folder`` every
    ``args.save_interval`` iterations, on Ctrl+C (only when
    ``args.interrupt`` is set) and once more after the loop finishes.
    """
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    dataset = COCODetection(image_path=cfg.dataset.train_images,
                            info_file=cfg.dataset.train_info,
                            transform=SSDAugmentation(MEANS))

    if args.validation_epoch > 0:
        setup_eval()
        val_dataset = COCODetection(image_path=cfg.dataset.valid_images,
                                    info_file=cfg.dataset.valid_info,
                                    transform=BaseTransform(MEANS))

    # Parallel wraps the underlying module, but when saving and loading we don't want that
    yolact_net = Yolact()
    net = yolact_net
    net.train()
    print('\n--- Generator created! ---')

    # NOTE: the discriminator image/segmentation size is manually set to 138;
    # this might change in the future (for example to 550).
    if cfg.pred_seg:
        dis_size = 138
        dis_net = Discriminator_Wgan(i_size=dis_size, s_size=dis_size)
        dis_net.train()
        print('--- Discriminator created! ---\n')

    if args.log:
        log = Log(cfg.name, args.log_folder, dict(args._get_kwargs()),
                  overwrite=(args.resume is None), log_gpu_stats=args.log_gpu)

    # I don't use the timer during training (I use a different timing method).
    # Apparently there's a race condition with multiple GPUs, so disable it just to be safe.
    timer.disable_all()

    # Both of these can set args.resume to None, so do them before the check
    if args.resume == 'interrupt':
        args.resume = SavePath.get_interrupt(args.save_folder)
    elif args.resume == 'latest':
        args.resume = SavePath.get_latest(args.save_folder, cfg.name)

    if args.resume is not None:
        print('Resuming training, loading {}...'.format(args.resume))
        yolact_net.load_weights(args.resume)

        if args.start_iter == -1:
            args.start_iter = SavePath.from_str(args.resume).iteration
    else:
        print('Initializing weights...')
        yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)

    # NOTE: Using the Ranger Optimizer for the generator
    optimizer_gen = Ranger(net.parameters(), lr=args.lr, weight_decay=args.decay)

    if cfg.pred_seg:
        optimizer_dis = optim.SGD(dis_net.parameters(), lr=cfg.dis_lr)
        # Not stepped anywhere in this function at the moment.
        schedule_dis = ReduceLROnPlateau(optimizer_dis, mode='min', patience=6, min_lr=1E-6)

    criterion = MultiBoxLoss(num_classes=cfg.num_classes,
                             pos_threshold=cfg.positive_iou_threshold,
                             neg_threshold=cfg.negative_iou_threshold,
                             negpos_ratio=cfg.ohem_negpos_ratio,
                             pred_seg=cfg.pred_seg)

    # Take the advice from WGAN
    criterion_dis = DiscriminatorLoss_Maskrcnn()
    criterion_gen = GeneratorLoss_Maskrcnn()

    if args.batch_alloc is not None:
        # e.g. args.batch_alloc: 24,24
        args.batch_alloc = [int(x) for x in args.batch_alloc.split(',')]
        if sum(args.batch_alloc) != args.batch_size:
            print('Error: Batch allocation (%s) does not sum to batch size (%s).' % (args.batch_alloc, args.batch_size))
            exit(-1)

    net = CustomDataParallel(NetLoss(net, criterion, pred_seg=cfg.pred_seg))
    if args.cuda:
        net = net.cuda()
        # NOTE(review): the discriminator is assumed to follow the generator
        # onto the GPU only when args.cuda is set - confirm the intended nesting.
        if cfg.pred_seg:
            dis_net = nn.DataParallel(dis_net)
            dis_net = dis_net.cuda()

    # Initialize everything with one dummy forward pass.
    if not cfg.freeze_bn:
        yolact_net.freeze_bn()  # Freeze bn so we don't kill our means
    dummy_input = torch.zeros(1, 3, cfg.max_size, cfg.max_size)
    if args.cuda:
        dummy_input = dummy_input.cuda()
    yolact_net(dummy_input)
    if not cfg.freeze_bn:
        yolact_net.freeze_bn(True)

    # start_iter == -1 means "not set", so clamp to 0.
    iteration = max(args.start_iter, 0)
    last_time = time.time()

    epoch_size = len(dataset) // args.batch_size
    num_epochs = math.ceil(cfg.max_iter / epoch_size)

    # Which learning rate adjustment step are we on? lr' = lr * gamma ^ step_index
    step_index = 0

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True)
    # The validation loader needs val_dataset, which only exists when
    # validation is enabled. It is not consumed anywhere in this function yet.
    val_loader = None
    if args.validation_epoch > 0:
        val_loader = data.DataLoader(val_dataset, args.batch_size,
                                     num_workers=args.num_workers*2,
                                     shuffle=True, collate_fn=detection_collate,
                                     pin_memory=True)

    def save_path(epoch, iteration):
        return SavePath(cfg.name, epoch, iteration).get_path(root=args.save_folder)

    time_avg = MovingAverage()

    global loss_types  # Forms the print order
    loss_avgs = {k: MovingAverage(100) for k in loss_types}

    # Mixed precision: the scaler is a pass-through when cfg.amp is False.
    amp_enable = cfg.amp
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enable)

    print('Begin training!')
    print()
    # try-except so you can use ctrl+c to save early and stop training
    try:
        for epoch in range(num_epochs):
            # Resume from start_iter
            if (epoch+1)*epoch_size < iteration:
                continue

            for datum in data_loader:
                # Stop if we've reached an epoch if we're resuming from start_iter
                if iteration == (epoch+1)*epoch_size:
                    break

                # Stop at the configured number of iterations even if mid-epoch
                if iteration == cfg.max_iter:
                    break

                # Change a config setting if we've reached the specified iteration
                changed = False
                for change in cfg.delayed_settings:
                    if iteration >= change[0]:
                        changed = True
                        cfg.replace(change[1])

                        # Reset the loss averages because things might have changed.
                        # loss_avgs is a dict, so iterate its values (the averages).
                        for avg in loss_avgs.values():
                            avg.reset()

                # If a config setting was changed, remove it from the list so we don't keep checking
                if changed:
                    cfg.delayed_settings = [x for x in cfg.delayed_settings if x[0] > iteration]

                # Warm up by linearly interpolating the learning rate from some smaller value
                if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                    set_lr(optimizer_gen,
                           (args.lr - cfg.lr_warmup_init) * (iteration / cfg.lr_warmup_until)
                           + cfg.lr_warmup_init)

                # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
                while step_index < len(cfg.lr_steps) and iteration >= cfg.lr_steps[step_index]:
                    step_index += 1
                    set_lr(optimizer_gen, args.lr * (args.gamma ** step_index))

                if cfg.pred_seg:
                    # ====== GAN Train ======
                    for _ in range(cfg.dis_iter):
                        if cfg.amp:
                            # ----- Discriminator part -----
                            # Inputs are detached: we do not want the gradient
                            # to flow back into the generator here.
                            with torch.cuda.amp.autocast():
                                _, output_pred, output_grou = _gan_forward(
                                    net, dis_net, datum, detach=True)
                                # Wasserstein Distance (Earth-Mover)
                                loss_dis = criterion_dis(input=output_grou, target=output_pred)

                            # The generator backward below also deposits
                            # gradients on the discriminator parameters, so
                            # clear them right before this backward.
                            optimizer_dis.zero_grad()
                            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                            scaler.scale(loss_dis).backward()
                            scaler.step(optimizer_dis)
                            scaler.update()

                            # clip the updated parameters
                            for par in dis_net.parameters():
                                par.data.clamp_(-cfg.clip_value, cfg.clip_value)

                            # ----- Generator part -----
                            # NOTE: inputs must NOT be detached here, so that the
                            # gradient flows back to the generator.
                            with torch.cuda.amp.autocast():
                                losses, output_pred, output_grou = _gan_forward(
                                    net, dis_net, datum, detach=False)
                                loss_gen = criterion_gen(input=output_grou, target=output_pred)

                                losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                                loss = sum([losses[k] for k in losses])
                                loss += loss_gen

                            # Generator backprop; optimizer_gen only holds the
                            # YOLACT parameters, so only the generator is updated.
                            scaler.scale(loss).backward()
                            scaler.step(optimizer_gen)
                            scaler.update()
                            optimizer_gen.zero_grad()
                        else:
                            # ----- Discriminator part -----
                            _, output_pred, output_grou = _gan_forward(
                                net, dis_net, datum, detach=True)
                            loss_dis = criterion_dis(input=output_grou, target=output_pred)

                            # Clear gradients left on the discriminator by the
                            # previous generator backward.
                            optimizer_dis.zero_grad()
                            loss_dis.backward()
                            optimizer_dis.step()

                            # clip the updated parameters
                            for par in dis_net.parameters():
                                par.data.clamp_(-cfg.clip_value, cfg.clip_value)

                            # ----- Generator part -----
                            losses, output_pred, output_grou = _gan_forward(
                                net, dis_net, datum, detach=False)
                            loss_gen = criterion_gen(input=output_grou, target=output_pred)

                            losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                            loss = sum([losses[k] for k in losses])
                            loss += loss_gen

                            loss.backward()  # Do this to free up vram even if loss is not finite
                            if torch.isfinite(loss).item():
                                # optimizer_gen only holds the YOLACT parameters,
                                # so only the generator is updated.
                                optimizer_gen.step()
                            # Zero AFTER the step; zeroing before it would throw
                            # the freshly computed gradients away.
                            optimizer_gen.zero_grad()
                else:
                    # ====== Normal YOLACT Train ======
                    # Zero the grad to get ready to compute gradients
                    optimizer_gen.zero_grad()

                    # Forward Pass + Compute loss at the same time (see CustomDataParallel and NetLoss)
                    losses = net(datum)

                    losses = {k: v.mean() for k, v in losses.items()}  # Mean here because Dataparallel
                    loss = sum([losses[k] for k in losses])

                    # Backprop
                    loss.backward()  # Do this to free up vram even if loss is not finite
                    if torch.isfinite(loss).item():
                        optimizer_gen.step()

                # Add the loss to the moving average for bookkeeping
                for k in losses:
                    loss_avgs[k].add(losses[k].item())

                cur_time = time.time()
                elapsed = cur_time - last_time
                last_time = cur_time

                # Exclude graph setup from the timing information
                if iteration != args.start_iter:
                    time_avg.add(elapsed)

                if iteration % 10 == 0:
                    eta_str = str(datetime.timedelta(
                        seconds=(cfg.max_iter-iteration) * time_avg.get_avg())).split('.')[0]

                    total = sum([loss_avgs[k].get_avg() for k in losses])
                    loss_labels = sum([[k, loss_avgs[k].get_avg()]
                                       for k in loss_types if k in losses], [])

                    # NOTE(review): the progress line is only printed in GAN mode;
                    # confirm that plain YOLACT runs are meant to stay silent.
                    if cfg.pred_seg:
                        print(('[%3d] %7d ||' + (' %s: %.3f |' * len(losses)) + ' T: %.3f || ETA: %s || timer: %.3f')
                              % tuple([epoch, iteration] + loss_labels + [total, eta_str, elapsed]), flush=True)

                # Loss Key:
                #  - B: Box Localization Loss
                #  - C: Class Confidence Loss
                #  - M: Mask Loss
                #  - P: Prototype Loss
                #  - D: Coefficient Diversity Loss
                #  - E: Class Existence Loss
                #  - S: Semantic Segmentation Loss
                #  - T: Total loss
                if args.log:
                    precision = 5
                    loss_info = {k: round(losses[k].item(), precision) for k in losses}
                    loss_info['T'] = round(loss.item(), precision)

                    if args.log_gpu:
                        log.log_gpu_stats = (iteration % 10 == 0)  # nvidia-smi is sloooow

                    # cur_lr is a module-level value, not set in this function.
                    log.log('train', loss=loss_info, epoch=epoch, iter=iteration,
                            lr=round(cur_lr, 10), elapsed=elapsed)

                    log.log_gpu_stats = args.log_gpu

                iteration += 1

                if iteration % args.save_interval == 0 and iteration != args.start_iter:
                    if args.keep_latest:
                        latest = SavePath.get_latest(args.save_folder, cfg.name)

                    print('Saving state, iter:', iteration)
                    yolact_net.save_weights(save_path(epoch, iteration))

                    if args.keep_latest and latest is not None:
                        if args.keep_latest_interval <= 0 or iteration % args.keep_latest_interval != args.save_interval:
                            print('Deleting old save...')
                            os.remove(latest)

            # This is done per epoch
            if args.validation_epoch > 0:
                if epoch % args.validation_epoch == 0 and epoch > 0:
                    cfg.gan_eval = False
                    # The discriminator only exists in GAN mode.
                    if cfg.pred_seg:
                        dis_net.eval()
                    compute_validation_map(epoch, iteration, yolact_net, val_dataset,
                                           log if args.log else None)
                    # Put the discriminator back in train mode for the next epoch.
                    if cfg.pred_seg:
                        dis_net.train()

        # Compute validation mAP after training is finished. val_dataset only
        # exists when validation is enabled, so skip it otherwise instead of
        # crashing before the final weights are saved.
        if args.validation_epoch > 0:
            compute_validation_map(epoch, iteration, yolact_net, val_dataset,
                                   log if args.log else None)
    except KeyboardInterrupt:
        if args.interrupt:
            print('Stopping early. Saving network...')

            # Delete previous copy of the interrupted network so we don't spam the weights folder
            SavePath.remove_interrupt(args.save_folder)

            yolact_net.save_weights(save_path(epoch, repr(iteration) + '_interrupt'))
        exit()

    yolact_net.save_weights(save_path(epoch, iteration))