def _binary_meter_stats(meter):
    """Return the scalar statistics of a binary (activity / overlap) meter.

    Reads the ``get_*`` accessors and the ``tp``/``tn``/``fp``/``fn`` tensor
    attributes of ``meter`` and converts each one with ``.item()``.
    """
    return {
        "fa": meter.get_fa().item(),
        "miss": meter.get_miss().item(),
        "precision": meter.get_precision().item(),
        "recall": meter.get_recall().item(),
        "matt": meter.get_matt().item(),
        "f1": meter.get_f1().item(),
        "tp": meter.tp.item(),
        "tn": meter.tn.item(),
        "fp": meter.fp.item(),
        "fn": meter.fn.item(),
    }


def _multi_meter_stats(meter):
    """Return the scalar statistics of the multi-class (speaker count) meter."""
    return {
        # NOTE(review): accuracy is reported in the "fa" slot for the count head,
        # as in the original reporting, so the logged series stay comparable.
        "fa": meter.get_accuracy().item(),
        "miss": meter.get_miss().item(),
        "precision": meter.get_precision().item(),
        "recall": meter.get_recall().item(),
        "matt": meter.get_matt().item(),
        "f1": meter.get_f1().item(),
        "tp": meter.get_tp().item(),
        "tn": meter.get_tn().item(),
        "fp": meter.get_fp().item(),
        "fn": meter.get_fn().item(),
    }


def _postfix_entries(prefix, stats):
    """Build the progress-bar entries displayed for one metric head."""
    return {
        prefix + "_miss": stats["miss"],
        prefix + "_fa": stats["fa"],
        prefix + "_prec": stats["precision"],
        prefix + "_recall": stats["recall"],
        prefix + "_matt": stats["matt"],
        prefix + "_f1": stats["f1"],
    }


def _log_meter_stats(writer, tag, loss, stats, epoch):
    """Forward one head's statistics to ``writer.log_metrics`` in its positional order."""
    writer.log_metrics(tag, loss, stats["fa"], stats["miss"], stats["recall"],
                       stats["precision"], stats["f1"], stats["matt"],
                       stats["tp"], stats["tn"], stats["fp"], stats["fn"], epoch)


def train(args, pt_dir, chkpt_path, trainloader, devloader, writer, logger, hp, hp_str):
    """Train the SLOCountNet model, validating and checkpointing after every epoch.

    Args:
        args: Command-line arguments (not read by this function).
        pt_dir: Directory where ``chkpt_<epoch>.pt`` files are written.
        chkpt_path: Checkpoint to resume from, or ``None`` to start a new run.
        trainloader: Iterable of ``(features, labels)`` training batches.
        devloader: Validation data, passed through to ``validate``.
        writer: Metrics writer exposing ``log_metrics``.
        logger: Logger used for progress and error messages.
        hp: Hyper-parameters (``hp.train.*``, ``hp.features.n_fft``).
        hp_str: Serialized hyper-parameters stored in each checkpoint.

    Returns:
        The lowest validation loss seen, or ``None`` if training stopped on an
        exception (the exception is logged, not re-raised).

    Raises:
        Exception: if ``hp.train.optimizer`` is not ``'adam'``.
    """
    model = get_SLOCountNet(hp).cuda()
    print("FOV: {}".format(model.get_fov(hp.features.n_fft)))
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum(np.prod(p.size()) for p in model_parameters)
    print("N_parameters : {}".format(params))
    model = DataParallel(model)

    if hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    start_epoch = 0
    best_loss = np.inf
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # 'step' holds the epoch that had just finished when the checkpoint was
        # written, so training resumes with the next epoch.
        start_epoch = checkpoint['step'] + 1
        # Older checkpoints do not store 'best_loss'.
        best_loss = checkpoint.get('best_loss', np.inf)
        # will use new given hparams.
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams is different from checkpoint.")
    else:
        logger.info("Starting new training run")

    def save_checkpoint(epoch):
        # Write model/optimizer state for `epoch` into pt_dir.
        save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % epoch)
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': epoch,
                'hp_str': hp_str,
                'best_loss': best_loss,
            }, save_path)
        logger.info("Saved checkpoint to: %s" % save_path)

    try:
        for epoch in range(start_epoch, hp.train.n_epochs):
            vad_scores = Binarymetrics.BinaryMeter()  # activity scores
            vod_scores = Binarymetrics.BinaryMeter()  # overlap scores
            count_scores = Binarymetrics.MultiMeter()  # Countnet scores
            model.train()
            tot_loss = 0
            n_batches = 0
            with tqdm(trainloader) as t:
                t.set_description("Epoch: {}".format(epoch))
                for features, labels in t:
                    features = features.cuda()
                    labels = labels.cuda()
                    preds = model(features)
                    loss = criterion(preds, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    loss = loss.item()
                    if loss > 1e8 or math.isnan(loss):  # check if exploded
                        logger.error("Loss exploded to %.02f at step %d!" % (loss, epoch))
                        raise Exception("Loss exploded")

                    # preds are exponentiated, so presumably log-probabilities over
                    # count classes along dim 1 — TODO confirm against the model.
                    # Activity: summed mass of classes 1-4; overlap: classes 2-4.
                    VADpreds = torch.sum(torch.exp(preds[:, 1:5, :]), dim=1).unsqueeze(1)
                    VADlabels = torch.sum(labels[:, 1:5, :], dim=1).unsqueeze(1)
                    vad_scores.update(VADpreds, VADlabels)
                    VODpreds = torch.sum(torch.exp(preds[:, 2:5, :]), dim=1).unsqueeze(1)
                    VODlabels = torch.sum(labels[:, 2:5, :], dim=1).unsqueeze(1)
                    vod_scores.update(VODpreds, VODlabels)
                    count_scores.update(
                        torch.argmax(torch.exp(preds), 1).unsqueeze(1),
                        torch.argmax(labels, 1).unsqueeze(1))

                    tot_loss += loss
                    n_batches += 1

                    # Meters accumulate over the epoch, so these are running stats.
                    vad_stats = _binary_meter_stats(vad_scores)
                    vod_stats = _binary_meter_stats(vod_scores)
                    count_stats = _multi_meter_stats(count_scores)
                    postfix = {}
                    postfix.update(_postfix_entries("vad", vad_stats))
                    postfix.update(_postfix_entries("vod", vod_stats))
                    postfix.update(_postfix_entries("count", count_stats))
                    t.set_postfix(loss=tot_loss / n_batches, **postfix)

            if n_batches == 0:
                raise RuntimeError("Training loader yielded no batches")

            # Log the epoch-mean loss (the value shown on the progress bar).
            mean_loss = tot_loss / n_batches
            _log_meter_stats(writer, "train_vad", mean_loss, vad_stats, epoch)
            _log_meter_stats(writer, "train_vod", mean_loss, vod_stats, epoch)
            _log_meter_stats(writer, "train_count", mean_loss, count_stats, epoch)

            # end epoch save model and validate it
            val_loss = validate(hp, model, devloader, writer, epoch)
            is_best = val_loss < best_loss
            if is_best:
                best_loss = val_loss
            # save_best == 0: save every epoch; otherwise save only on improvement.
            if hp.train.save_best == 0 or is_best:
                save_checkpoint(epoch)
        return best_loss
    except Exception as e:
        logger.error("Exiting due to exception: %s" % e)
        traceback.print_exc()
        return None
class SSRunner(object):
    """Train, evaluate and run inference for a semantic-segmentation network.

    Every setting (dataset paths and sizes, batch size, learning rate and
    milestones, epoch counts, output directories, result text file) is read
    from ``config``.
    """

    def __init__(self, config):
        self.config = config

        # Data
        self.dataset_ss_train, _, self.dataset_ss_val = DatasetUtil.get_dataset_by_type(
            DatasetUtil.dataset_type_ss, self.config.ss_size,
            is_balance=self.config.is_balance_data, data_root=self.config.data_root_path,
            train_label_path=self.config.label_path, max_size=self.config.max_size)
        self.data_loader_ss_train = DataLoader(self.dataset_ss_train, self.config.ss_batch_size, True,
                                               num_workers=16, drop_last=True)
        # NOTE(review): drop_last=True on the validation loader excludes the final
        # partial batch from evaluation metrics.
        self.data_loader_ss_val = DataLoader(self.dataset_ss_val, self.config.ss_batch_size, False,
                                             num_workers=16, drop_last=True)

        # Model
        self.net = self.config.Net(num_classes=self.config.ss_num_classes,
                                   output_stride=self.config.output_stride, arch=self.config.arch)
        if self.config.only_train_ss:
            self.net = BalancedDataParallel(0, self.net, dim=0).cuda()
        else:
            self.net = DataParallel(self.net).cuda()
        cudnn.benchmark = True

        # Optimize: the classifier head trains with 10x the backbone learning rate.
        self.optimizer = optim.SGD(params=[
            {'params': self.net.module.model.backbone.parameters(), 'lr': self.config.ss_lr},
            {'params': self.net.module.model.classifier.parameters(), 'lr': self.config.ss_lr * 10},
        ], lr=self.config.ss_lr, momentum=0.9, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.config.ss_milestones, gamma=0.1)

        # Loss: pixels labelled 255 are ignored.
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=255, reduction='mean').cuda()

    def _save_model(self, file_name):
        """Save ``self.net.state_dict()`` as ``ss_model_dir/file_name`` and log the path."""
        save_file_name = Tools.new_dir(os.path.join(self.config.ss_model_dir, file_name))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name), txt_path=self.config.ss_save_result_txt)
        return save_file_name

    def train_ss(self, start_epoch=0, model_file_name=None):
        """Train for epochs ``start_epoch .. ss_epoch_num - 1``.

        About ten times per epoch the model is evaluated; a snapshot is saved whenever
        the validation Mean IoU improves. Periodic and final snapshots are also saved.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name), txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)

        best_iou = 0.0
        for epoch in range(start_epoch, self.config.ss_epoch_num):
            Tools.print()
            Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr'], self.optimizer.param_groups[1]['lr']),
                txt_path=self.config.ss_save_result_txt)

            # 1 Train the model
            all_loss = 0.0
            self.net.train()
            if self.config.is_balance_data:
                self.dataset_ss_train.reset()
            # Measured after reset(), which may change the dataset.
            n_batches = len(self.data_loader_ss_train)
            # Evaluate ~10 times per epoch; max(1, ...) avoids a modulo by zero
            # when an epoch has fewer than 10 batches.
            eval_every = max(1, n_batches // 10)
            for i, (inputs, labels) in tqdm(enumerate(self.data_loader_ss_train), total=n_batches):
                inputs, labels = inputs.float().cuda(), labels.long().cuda()
                self.optimizer.zero_grad()
                result = self.net(inputs)
                loss = self.ce_loss(result, labels)
                loss.backward()
                self.optimizer.step()
                all_loss += loss.item()

                if (i + 1) % eval_every == 0:
                    score = self.eval_ss(epoch=epoch)
                    # eval_ss leaves the net in eval mode; switch back before training on.
                    self.net.train()
                    mean_iou = score["Mean IoU"]
                    if mean_iou > best_iou:
                        best_iou = mean_iou
                        self._save_model("ss_{}_{}_{}.pth".format(epoch, i, best_iou))
                        Tools.print()
            self.scheduler.step()

            Tools.print("[E:{:3d}/{:3d}] ss loss:{:.4f}".format(
                epoch, self.config.ss_epoch_num, all_loss / n_batches),
                txt_path=self.config.ss_save_result_txt)

            # 2 Save the model
            if epoch % self.config.ss_save_epoch_freq == 0:
                Tools.print()
                self._save_model("ss_{}.pth".format(epoch))
                Tools.print()

            # 3 Evaluate the model
            if epoch % self.config.ss_eval_epoch_freq == 0:
                self.eval_ss(epoch=epoch)

        # Final Save
        Tools.print()
        self._save_model("ss_final_{}.pth".format(self.config.ss_epoch_num))
        Tools.print()
        self.eval_ss(epoch=self.config.ss_epoch_num)

    def eval_ss(self, epoch=0, model_file_name=None):
        """Evaluate on the validation loader, log the metrics and return the score dict.

        Leaves the network in eval mode.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name), txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs, labels) in tqdm(enumerate(self.data_loader_ss_val),
                                            total=len(self.data_loader_ss_val)):
                inputs = inputs.float().cuda()
                labels = labels.long().cuda()
                outputs = self.net(inputs)

                # Predicted class = argmax over dim 1 of the network output.
                preds = outputs.detach().max(dim=1)[1].cpu().numpy()
                targets = labels.cpu().numpy()
                metrics.update(targets, preds)

        score = metrics.get_results()
        Tools.print("{} {}".format(epoch, metrics.to_str(score)), txt_path=self.config.ss_save_result_txt)
        return score

    def inference_ss(self, model_file_name=None, data_loader=None, save_path=None):
        """Predict each image at its original resolution and score against its label.

        Each batch must hold exactly one image. The network runs on every tensor in
        ``inputs``, the upsampled logits are averaged, and the argmax is the
        prediction. When ``save_path`` is set, the image, colourised label and
        colourised prediction go to ``save_path`` and the raw prediction to
        ``<save_path>_final``. Images whose raw prediction already exists there are
        skipped and not counted in the metrics.
        """
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name), txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)

        # NOTE(review): with save_path=None this resolves to "None_final".
        final_save_path = Tools.new_dir("{}_final".format(save_path))
        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        # Images smaller than this in both dimensions are resized on the GPU,
        # larger ones on the CPU (presumably to bound GPU memory).
        max_size = 1000
        with torch.no_grad():
            for i, (inputs, labels, image_info_list) in tqdm(enumerate(data_loader), total=len(data_loader)):
                assert len(image_info_list) == 1
                image_path = image_info_list[0]

                # Labels
                with Image.open(image_path) as image:
                    size = image.size  # PIL reports (width, height)
                basename = os.path.basename(image_path)
                final_name = os.path.join(final_save_path, basename.replace(".JPEG", ".png"))
                if os.path.exists(final_name):
                    continue

                is_small = size[0] < max_size and size[1] < max_size
                if is_small:
                    targets = F.interpolate(torch.unsqueeze(labels[0].float().cuda(), dim=0),
                                            size=(size[1], size[0]), mode="nearest").detach().cpu()
                else:
                    targets = F.interpolate(torch.unsqueeze(labels[0].float(), dim=0),
                                            size=(size[1], size[0]), mode="nearest")
                targets = targets[0].long().numpy()

                # Prediction: average the upsampled outputs over all input tensors.
                outputs = 0
                for input_one in inputs:
                    output_one = self.net(input_one.float().cuda())
                    if is_small:
                        outputs += F.interpolate(output_one, size=(size[1], size[0]), mode="bilinear",
                                                 align_corners=False).detach().cpu()
                    else:
                        outputs += F.interpolate(output_one.detach().cpu(), size=(size[1], size[0]),
                                                 mode="bilinear", align_corners=False)
                outputs = outputs / len(inputs)
                preds = outputs.max(dim=1)[1].numpy()

                # Metrics
                metrics.update(targets, preds)

                if save_path:
                    with Image.open(image_path) as image:
                        image.save(os.path.join(save_path, basename))
                    DataUtil.gray_to_color(np.asarray(targets[0], dtype=np.uint8)).save(
                        os.path.join(save_path, basename.replace(".JPEG", "_l.png")))
                    DataUtil.gray_to_color(np.asarray(preds[0], dtype=np.uint8)).save(
                        os.path.join(save_path, basename.replace(".JPEG", ".png")))
                    Image.fromarray(np.asarray(preds[0], dtype=np.uint8)).save(final_name)

        score = metrics.get_results()
        Tools.print("{}".format(metrics.to_str(score)), txt_path=self.config.ss_save_result_txt)
        return score

    def load_model(self, model_file_name):
        """Load a state_dict from ``model_file_name`` into ``self.net`` (strict)."""
        Tools.print("Load model form {}".format(model_file_name), txt_path=self.config.ss_save_result_txt)
        checkpoint = torch.load(model_file_name)
        # Snapshots written by this class come from the DataParallel-wrapped net, so
        # their keys already match self.net. The old single-GPU "module." stripping
        # was disabled and only risked a KeyError when CUDA_VISIBLE_DEVICES is unset.
        self.net.load_state_dict(checkpoint, strict=True)
        Tools.print("Restore from {}".format(model_file_name), txt_path=self.config.ss_save_result_txt)

    def stat(self):
        """Print model statistics for a 3-channel ``ss_size`` x ``ss_size`` input."""
        stat(self.net, (3, self.config.ss_size, self.config.ss_size))
def stage1_train(args):
    """Train the Stage-I generator and discriminator on 64x64 Birds images.

    Alternates one discriminator step and one generator step per batch. Optionally
    logs scalars to TensorBoard, decays the learning rate, and writes periodic
    checkpoints plus final ``generator.pth``/``discriminator.pth`` files to
    ``args.s1_checkpoint_dir``.
    """
    logger = init_logger(args)
    summary_writer = SummaryWriter(args.s1_summary_path) if args.summary else None

    dataset = Birds(args.data_dir, split='train', im_size=64)
    dataloader = DataLoader(dataset, batch_size=args.s1_batch_size, shuffle=True,
                            num_workers=8, drop_last=True)

    generator = Stage1Generator(args.txt_embedding_dim, args.c_dim, args.z_dim, args.gf_dim).cuda()
    print('generator={}'.format(generator))
    discriminator = Stage1Discriminator(args.df_dim, args.c_dim).cuda()
    print('discriminator={}'.format(discriminator))

    device_ids = list(range(torch.cuda.device_count()))
    generator = DataParallel(generator, device_ids)
    discriminator = DataParallel(discriminator, device_ids)

    g_parameters = [p for p in generator.parameters() if p.requires_grad]
    d_parameters = [p for p in discriminator.parameters() if p.requires_grad]
    g_optimizer = torch.optim.Adam(g_parameters, args.lr, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(d_parameters, args.lr, betas=(0.5, 0.999))

    # drop_last=True keeps every batch at s1_batch_size, matching these targets.
    r_labels = torch.ones((args.s1_batch_size,), device='cuda:0')
    f_labels = torch.zeros((args.s1_batch_size,), device='cuda:0')
    criterion = nn.BCELoss()
    cur_lr = args.lr

    for epoch in range(args.total_epoch):
        for idx, (r_imgs, txt_embeddings) in enumerate(dataloader):
            r_imgs = r_imgs.cuda()
            txt_embeddings = txt_embeddings.cuda()

            # discriminator: generated images are detached so only D is updated.
            noise = torch.randn((args.s1_batch_size, args.z_dim), device='cuda:0')
            x, mu, logvar = generator(txt_embeddings, noise)
            d_loss, r_loss, w_loss, f_loss = discriminator_loss(
                discriminator, r_imgs, x.detach(), mu.detach(), r_labels, f_labels, criterion)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # generator: adversarial loss plus KL regulariser on (mu, logvar).
            noise = torch.randn((args.s1_batch_size, args.z_dim), device='cuda:0')
            x, mu, logvar = generator(txt_embeddings, noise)
            logits = discriminator(mu.detach(), x)
            g_loss = criterion(logits, r_labels)
            kl_loss_ = kl_loss(mu, logvar)
            g_loss += kl_loss_
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if summary_writer is not None and idx % args.summary_iters == 0 and idx > 0:
                global_step = epoch * len(dataloader) + idx
                summary_writer.add_scalar('d_loss', d_loss.item(), global_step)
                summary_writer.add_scalar('r_loss', r_loss.item(), global_step)
                summary_writer.add_scalar('w_loss', w_loss.item(), global_step)
                summary_writer.add_scalar('f_loss', f_loss.item(), global_step)
                summary_writer.add_scalar('g_loss', g_loss.item(), global_step)
                summary_writer.add_scalar('kl_loss', kl_loss_.item(), global_step)

        if epoch % args.lr_decay_every_epoch == 0 and epoch > 0:
            logger.info(f'lr decay: {cur_lr}')
            cur_lr *= args.lr_decay_ratio
            # Update the existing optimizers so Adam keeps its moment estimates.
            for optimizer in (g_optimizer, d_optimizer):
                for param_group in optimizer.param_groups:
                    param_group['lr'] = cur_lr

        if epoch % args.display_epoch == 0 and epoch > 0:
            logger.info(f'epoch:{epoch}, lr={cur_lr}, d_loss={d_loss.item()}, r_loss={r_loss.item()}, '
                        f'w_loss={w_loss.item()}, f_loss={f_loss.item()}, g_loss={g_loss.item()}, '
                        f'kl_loss={kl_loss_.item()}')

        if epoch % args.checkpoint_epoch == 0 and epoch > 0:
            os.makedirs(args.s1_checkpoint_dir, exist_ok=True)
            logger.info(f'saving checkpoints_{epoch}')
            torch.save(generator.state_dict(),
                       os.path.join(args.s1_checkpoint_dir, f'generator_epoch_{epoch}.pth'))
            torch.save(discriminator.state_dict(),
                       os.path.join(args.s1_checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))

    # Final weights (the directory may not exist if no periodic checkpoint ran).
    os.makedirs(args.s1_checkpoint_dir, exist_ok=True)
    torch.save(generator.state_dict(), os.path.join(args.s1_checkpoint_dir, 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join(args.s1_checkpoint_dir, 'discriminator.pth'))
    if summary_writer is not None:
        summary_writer.close()