def train_model(model, model_name, hyperparams, device, epochs):
    """Train Model

    Generic wrapper that calls the model's training function: moves the model
    to the device, builds the Adam optimizer and StepLR scheduler from the
    hyperparameter dict, runs training, then plots and saves the results.
    """
    print('Beginning Training for: ', model_name)
    print('------------------------------------')
    results = {}
    if torch.cuda.device_count() > 1:
        print("Using ", torch.cuda.device_count(), " GPUs.")
        print('------------------------------------')
        model = DataParallel(model)
    model = model.to(device=device)
    optimizer = optim.Adam(model.parameters(),
                           betas=hyperparams['betas'],
                           lr=hyperparams['learning_rate'],
                           weight_decay=hyperparams['L2_reg'])
    lr_updater = lr_scheduler.StepLR(optimizer,
                                     hyperparams['lr_decay_epochs'],
                                     hyperparams['lr_decay'])
    results = train(model, optimizer, lr_updater, results, epochs=epochs)
    plot_results(results, model_name, save=True)
    np.save(model_name, results)
    return results
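# A minimal usage sketch for train_model (my_model and the values below are
# hypothetical; the dict keys are the ones train_model actually reads):
hyperparams = {
    'betas': (0.9, 0.999),    # Adam momentum coefficients
    'learning_rate': 1e-3,
    'L2_reg': 1e-4,           # Adam weight decay
    'lr_decay_epochs': 30,    # StepLR step size (epochs between decays)
    'lr_decay': 0.1,          # StepLR gamma (multiplicative decay factor)
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# results = train_model(my_model, 'my_model', hyperparams, device, epochs=100)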
def main():
    train_dataset = hrb_input.TrainDataset()
    val_dataset = hrb_input.ValidationDataset()
    test_dataset = hrb_input.TestDataset()
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True,
                            num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             num_workers=8, pin_memory=True)

    model = network.resnet50(num_classes=1000)
    model = DataParallel(model, device_ids=[0, 1, 2])
    model.to(device)
    loss_func = F.cross_entropy
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=LR_STEPS,
                                               gamma=LR_GAMMA)

    best_loss = 100.0
    patience_counter = 0
    for epoch in range(EPOCHS):
        train(model, optimizer, loss_func, train_loader, epoch)
        val_loss = validate(model, loss_func, val_loader)
        if val_loss < best_loss:
            torch.save(model, MODEL_PATH)
            print('Saving improved model')
            print()
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print('Epoch(s) since best model: ', patience_counter)
            print()
            if patience_counter >= EARLY_STOPPING_EPOCHS:
                print('Early Stopping ...')
                print()
                break
        scheduler.step()

    print('Predicting labels from best trained model')
    predict(test_loader)
class nnUNetTrainerV2CascadeFullRes_DP(nnUNetTrainerV2CascadeFullRes):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None,
                 batch_dice=True, stage=None, unpack_data=True, deterministic=True,
                 num_gpus=1, distribute_batch_size=False, fp16=False,
                 previous_trainer="nnUNetTrainerV2_DP"):
        super().__init__(plans_file, fold, output_folder, dataset_directory,
                         batch_dice, stage, unpack_data, deterministic,
                         previous_trainer, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory,
                          batch_dice, stage, unpack_data, deterministic, num_gpus,
                          distribute_batch_size, fp16, previous_trainer)
        self.num_gpus = num_gpus
        self.distribute_batch_size = distribute_batch_size
        self.dice_do_BG = False
        self.dice_smooth = 1e-5
        if self.output_folder is not None:
            task = self.output_folder.split("/")[-3]
            plans_identifier = self.output_folder.split("/")[-2].split("__")[-1]
            self.folder_with_segs_from_prev_stage = join(
                network_training_output_dir, "3d_lowres", task,
                previous_trainer + "__" + plans_identifier, "pred_next_stage")
        else:
            self.folder_with_segs_from_prev_stage = None
        print(self.folder_with_segs_from_prev_stage)

    def get_basic_generators(self):
        self.load_dataset()
        self.do_split()
        if self.threeD:
            dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size,
                                 self.patch_size, self.batch_size, True,
                                 oversample_foreground_percent=self.oversample_foreground_percent,
                                 pad_mode="constant", pad_sides=self.pad_all_sides)
            dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size,
                                  self.batch_size, True,
                                  oversample_foreground_percent=self.oversample_foreground_percent,
                                  pad_mode="constant", pad_sides=self.pad_all_sides)
        else:
            raise NotImplementedError("2D has no cascade")
        return dl_tr, dl_val

    def process_plans(self, plans):
        super().process_plans(plans)
        if not self.distribute_batch_size:
            self.batch_size = self.num_gpus * self.plans['plans_per_stage'][self.stage]['batch_size']
        else:
            if self.batch_size < self.num_gpus:
                print("WARNING: self.batch_size < self.num_gpus. "
                      "Will not be able to use the GPUs well")
            elif self.batch_size % self.num_gpus != 0:
                print("WARNING: self.batch_size % self.num_gpus != 0. "
                      "Will not be able to use the GPUs well")

    def initialize(self, training=True, force_load_plans=False):
        if not self.was_initialized:
            if force_load_plans or (self.plans is None):
                self.load_plans_file()
            self.process_plans(self.plans)
            self.setup_DA_params()

            # Here we wrap the loss for deep supervision: deeper outputs get
            # exponentially smaller weights, the coarsest output is masked out,
            # and the weights are renormalized to sum to 1.
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
            mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.loss_weights = weights

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                if not isdir(self.folder_with_segs_from_prev_stage):
                    raise RuntimeError(
                        "Cannot run final stage of cascade. Run corresponding 3d_lowres "
                        "first and predict the segmentations for the next stage")
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. "
                          "Pray you are not using 2d or you will wait all winter for "
                          "your model to finish!")
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            assert isinstance(self.network, (SegmentationNetwork, DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """Replace the generic UNet with the Generic_UNet_DP implementation for speed."""
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d
        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, InitWeights_He(1e-2),
            self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes,
            False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def run_training(self):
        self.maybe_update_lr(self.epoch)  # amp must be initialized before DP
        ds = self.network.do_ds
        self.network.do_ds = True
        # self.network = DataParallel(self.network, tuple(range(self.num_gpus)), )
        self.network = DataParallel(self.network,
                                    device_ids=list(range(0, self.num_gpus)))
        ret = nnUNetTrainer.run_training(self)
        self.network = self.network.module
        self.network.do_ds = ds
        return ret

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            with autocast():
                ret = self.network(data, target,
                                   return_hard_tp_fp_fn=run_online_evaluation)
                if run_online_evaluation:
                    ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                    self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
                else:
                    ces, tps, fps, fns = ret
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ret = self.network(data, target,
                               return_hard_tp_fp_fn=run_online_evaluation)
            if run_online_evaluation:
                ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
            else:
                ces, tps, fps, fns = ret
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)
            if do_backprop:
                l.backward()
                clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()
        return l.detach().cpu().numpy()

    def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
        tp_hard = tp_hard.detach().cpu().numpy().mean(0)
        fp_hard = fp_hard.detach().cpu().numpy().mean(0)
        fn_hard = fn_hard.detach().cpu().numpy().mean(0)
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))

    def compute_loss(self, ces, tps, fps, fns):
        # weighted sum over deep-supervision outputs of CE + soft dice
        loss = None
        for i in range(len(ces)):
            if not self.dice_do_BG:
                tp = tps[i][:, 1:]
                fp = fps[i][:, 1:]
                fn = fns[i][:, 1:]
            else:
                tp = tps[i]
                fp = fps[i]
                fn = fns[i]
            if self.batch_dice:
                tp = tp.sum(0)
                fp = fp.sum(0)
                fn = fn.sum(0)
            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth
            dice_loss = (-nominator / denominator).mean()
            if loss is None:
                loss = self.loss_weights[i] * (ces[i].mean() + dice_loss)
            else:
                loss += self.loss_weights[i] * (ces[i].mean() + dice_loss)
        return loss
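# Worked example of the deep-supervision weighting computed in initialize()
# above: for a 5-level U-Net, output i is weighted 1 / 2**i, the coarsest
# output is masked out, and the weights are renormalized to sum to 1.
import numpy as np

net_numpool = 5
weights = np.array([1 / (2 ** i) for i in range(net_numpool)])  # [1, .5, .25, .125, .0625]
weights[-1] = 0                       # drop the lowest-resolution output
weights = weights / weights.sum()     # ~[0.533, 0.267, 0.133, 0.067, 0.0]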
class nnUNetTrainerV2_DP(nnUNetTrainerV2):
    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None,
                 batch_dice=True, stage=None, unpack_data=True, deterministic=True,
                 num_gpus=1, distribute_batch_size=False, fp16=False):
        super(nnUNetTrainerV2_DP, self).__init__(plans_file, fold, output_folder,
                                                 dataset_directory, batch_dice, stage,
                                                 unpack_data, deterministic, fp16)
        self.init_args = (plans_file, fold, output_folder, dataset_directory,
                          batch_dice, stage, unpack_data, deterministic, num_gpus,
                          distribute_batch_size, fp16)
        self.num_gpus = num_gpus
        self.distribute_batch_size = distribute_batch_size
        self.dice_smooth = 1e-5
        self.dice_do_BG = False
        self.loss = None
        self.loss_weights = None

    def setup_DA_params(self):
        super(nnUNetTrainerV2_DP, self).setup_DA_params()
        self.data_aug_params['num_threads'] = 8 * self.num_gpus

    def process_plans(self, plans):
        super(nnUNetTrainerV2_DP, self).process_plans(plans)
        if not self.distribute_batch_size:
            self.batch_size = self.num_gpus * self.plans['plans_per_stage'][self.stage]['batch_size']
        else:
            if self.batch_size < self.num_gpus:
                print("WARNING: self.batch_size < self.num_gpus. "
                      "Will not be able to use the GPUs well")
            elif self.batch_size % self.num_gpus != 0:
                print("WARNING: self.batch_size % self.num_gpus != 0. "
                      "Will not be able to use the GPUs well")

    def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - only run this code once
        - loss function wrapper for deep supervision

        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            os.makedirs(self.output_folder, exist_ok=True)
            if force_load_plans or (self.plans is None):
                self.load_plans_file()
            self.process_plans(self.plans)
            self.setup_DA_params()

            # Here we configure the loss for deep supervision: deeper outputs
            # get exponentially smaller weights, the coarsest output is masked
            # out, and the weights are renormalized to sum to 1.
            net_numpool = len(self.net_num_pool_op_kernel_sizes)
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
            mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.loss_weights = weights

            self.folder_with_preprocessed_data = join(
                self.dataset_directory,
                self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. "
                          "Pray you are not using 2d or you will wait all winter for "
                          "your model to finish!")
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    pin_memory=self.pin_memory)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass
            self.initialize_network()
            self.initialize_optimizer_and_scheduler()
            assert isinstance(self.network, (SegmentationNetwork, DataParallel))
        else:
            self.print_to_log_file(
                'self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def initialize_network(self):
        """Replace the generic UNet with the Generic_UNet_DP implementation for speed."""
        if self.threeD:
            conv_op = nn.Conv3d
            dropout_op = nn.Dropout3d
            norm_op = nn.InstanceNorm3d
        else:
            conv_op = nn.Conv2d
            dropout_op = nn.Dropout2d
            norm_op = nn.InstanceNorm2d
        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        dropout_op_kwargs = {'p': 0, 'inplace': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        self.network = Generic_UNet_DP(
            self.num_input_channels, self.base_num_features, self.num_classes,
            len(self.net_num_pool_op_kernel_sizes), self.conv_per_stage, 2,
            conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
            net_nonlin, net_nonlin_kwargs, True, False, InitWeights_He(1e-2),
            self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes,
            False, True, True)
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr,
                                         weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None

    def run_training(self):
        self.maybe_update_lr(self.epoch)  # amp must be initialized before DP
        ds = self.network.do_ds
        self.network.do_ds = True
        self.network = DataParallel(self.network, tuple(range(self.num_gpus)))
        ret = nnUNetTrainer.run_training(self)
        self.network = self.network.module
        self.network.do_ds = ds
        return ret

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)
        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)
        self.optimizer.zero_grad()
        if self.fp16:
            with autocast():
                ret = self.network(data, target,
                                   return_hard_tp_fp_fn=run_online_evaluation)
                if run_online_evaluation:
                    ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                    self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
                else:
                    ces, tps, fps, fns = ret
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)
            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ret = self.network(data, target,
                               return_hard_tp_fp_fn=run_online_evaluation)
            if run_online_evaluation:
                ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
            else:
                ces, tps, fps, fns = ret
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)
            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()
        return l.detach().cpu().numpy()

    def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
        tp_hard = tp_hard.detach().cpu().numpy().mean(0)
        fp_hard = fp_hard.detach().cpu().numpy().mean(0)
        fn_hard = fn_hard.detach().cpu().numpy().mean(0)
        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))

    def compute_loss(self, ces, tps, fps, fns):
        # we now need to effectively reimplement the loss
        loss = None
        for i in range(len(ces)):
            if not self.dice_do_BG:
                tp = tps[i][:, 1:]
                fp = fps[i][:, 1:]
                fn = fns[i][:, 1:]
            else:
                tp = tps[i]
                fp = fps[i]
                fn = fns[i]
            if self.batch_dice:
                tp = tp.sum(0)
                fp = fp.sum(0)
                fn = fn.sum(0)
            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth
            dice_loss = (-nominator / denominator).mean()
            if loss is None:
                loss = self.loss_weights[i] * (ces[i].mean() + dice_loss)
            else:
                loss += self.loss_weights[i] * (ces[i].mean() + dice_loss)
        return loss
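# self.lr_scheduler is None above because nnU-Net updates the learning rate
# manually once per epoch (maybe_update_lr). A minimal sketch of that
# polynomial decay schedule, assuming nnU-Net's default exponent of 0.9:
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs) ** exponent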
def train(args, pt_dir, chkpt_path, trainloader, devloader, writer, logger, hp, hp_str):
    model = get_SLOCountNet(hp).cuda()
    print("FOV: {}".format(model.get_fov(hp.features.n_fft)))
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("N_parameters : {}".format(params))
    model = DataParallel(model)

    if hp.train.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
    else:
        raise Exception("%s optimizer not supported" % hp.train.optimizer)

    epoch = 0
    best_loss = np.inf
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['step']
        # will use the newly given hparams
        if hp_str != checkpoint['hp_str']:
            logger.warning("New hparams are different from checkpoint.")
    else:
        logger.info("Starting new training run")

    try:
        for epoch in range(epoch, hp.train.n_epochs):
            vad_scores = Binarymetrics.BinaryMeter()   # voice activity scores
            vod_scores = Binarymetrics.BinaryMeter()   # overlap scores
            count_scores = Binarymetrics.MultiMeter()  # CountNet scores
            model.train()
            tot_loss = 0
            with tqdm(trainloader) as t:
                t.set_description("Epoch: {}".format(epoch))
                for count, batch in enumerate(trainloader):
                    features, labels = batch
                    features = features.cuda()
                    labels = labels.cuda()
                    preds = model(features)
                    loss = criterion(preds, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # compute proper metrics for VAD
                    loss = loss.item()
                    if loss > 1e8 or math.isnan(loss):  # check if loss exploded
                        logger.error("Loss exploded to %.02f at step %d!" % (loss, epoch))
                        raise Exception("Loss exploded")

                    # aggregate count-class scores into voice activity
                    # (count >= 1) and overlap (count >= 2) probabilities
                    VADpreds = torch.sum(torch.exp(preds[:, 1:5, :]), dim=1).unsqueeze(1)
                    VADlabels = torch.sum(labels[:, 1:5, :], dim=1).unsqueeze(1)
                    vad_scores.update(VADpreds, VADlabels)
                    VODpreds = torch.sum(torch.exp(preds[:, 2:5, :]), dim=1).unsqueeze(1)
                    VODlabels = torch.sum(labels[:, 2:5, :], dim=1).unsqueeze(1)
                    vod_scores.update(VODpreds, VODlabels)
                    count_scores.update(torch.argmax(torch.exp(preds), 1).unsqueeze(1),
                                        torch.argmax(labels, 1).unsqueeze(1))
                    tot_loss += loss

                    vad_fa = vad_scores.get_fa().item()
                    vad_miss = vad_scores.get_miss().item()
                    vad_precision = vad_scores.get_precision().item()
                    vad_recall = vad_scores.get_recall().item()
                    vad_matt = vad_scores.get_matt().item()
                    vad_f1 = vad_scores.get_f1().item()
                    vad_tp = vad_scores.tp.item()
                    vad_tn = vad_scores.tn.item()
                    vad_fp = vad_scores.fp.item()
                    vad_fn = vad_scores.fn.item()

                    vod_fa = vod_scores.get_fa().item()
                    vod_miss = vod_scores.get_miss().item()
                    vod_precision = vod_scores.get_precision().item()
                    vod_recall = vod_scores.get_recall().item()
                    vod_matt = vod_scores.get_matt().item()
                    vod_f1 = vod_scores.get_f1().item()
                    vod_tp = vod_scores.tp.item()
                    vod_tn = vod_scores.tn.item()
                    vod_fp = vod_scores.fp.item()
                    vod_fn = vod_scores.fn.item()

                    count_fa = count_scores.get_accuracy().item()
                    count_miss = count_scores.get_miss().item()
                    count_precision = count_scores.get_precision().item()
                    count_recall = count_scores.get_recall().item()
                    count_matt = count_scores.get_matt().item()
                    count_f1 = count_scores.get_f1().item()
                    count_tp = count_scores.get_tp().item()
                    count_tn = count_scores.get_tn().item()
                    count_fp = count_scores.get_fp().item()
                    count_fn = count_scores.get_fn().item()

                    t.set_postfix(loss=tot_loss / (count + 1),
                                  vad_miss=vad_miss, vad_fa=vad_fa,
                                  vad_prec=vad_precision, vad_recall=vad_recall,
                                  vad_matt=vad_matt, vad_f1=vad_f1,
                                  vod_miss=vod_miss, vod_fa=vod_fa,
                                  vod_prec=vod_precision, vod_recall=vod_recall,
                                  vod_matt=vod_matt, vod_f1=vod_f1,
                                  count_miss=count_miss, count_fa=count_fa,
                                  count_prec=count_precision, count_recall=count_recall,
                                  count_matt=count_matt, count_f1=count_f1)
                    t.update()

            writer.log_metrics("train_vad", loss, vad_fa, vad_miss, vad_recall,
                               vad_precision, vad_f1, vad_matt, vad_tp, vad_tn,
                               vad_fp, vad_fn, epoch)
            writer.log_metrics("train_vod", loss, vod_fa, vod_miss, vod_recall,
                               vod_precision, vod_f1, vod_matt, vod_tp, vod_tn,
                               vod_fp, vod_fn, epoch)
            writer.log_metrics("train_count", loss, count_fa, count_miss,
                               count_recall, count_precision, count_f1, count_matt,
                               count_tp, count_tn, count_fp, count_fn, epoch)

            # end of epoch: validate, then save the model
            val_loss = validate(hp, model, devloader, writer, epoch)
            if hp.train.save_best == 0:
                save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % epoch)
                torch.save({'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'step': epoch,
                            'hp_str': hp_str}, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)
            elif val_loss < best_loss:  # save only when best
                best_loss = val_loss
                save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % epoch)
                torch.save({'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'step': epoch,
                            'hp_str': hp_str}, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)
        return best_loss
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
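# Sketch of the metric aggregation used in the loop above, assuming `preds`
# holds per-frame log-probabilities over speaker counts 0..4 with shape
# [batch, 5, frames]: voice activity is P(count >= 1) and overlap is
# P(count >= 2), obtained by summing the exponentiated class scores.
import torch

preds = torch.log_softmax(torch.randn(2, 5, 10), dim=1)       # dummy log-probs
p_vad = torch.exp(preds[:, 1:5, :]).sum(dim=1, keepdim=True)  # P(count >= 1)
p_ovl = torch.exp(preds[:, 2:5, :]).sum(dim=1, keepdim=True)  # P(count >= 2)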
            state=trk_state)
        loss_j = model.module.loss(output, frame_data, verbose=print_flag)
        losses.append(loss_j)
        # print(model._cs_pos[0].weight.mean(), 'after')
    loss, logs, logh = parse_losses(losses)
    if isinstance(loss, torch.Tensor):
        # for name, one in model.named_parameters():
        #     if one.grad is not None:
        #         print(name, one.grad.mean())
        loss.backward()
        # for name, one in model.named_parameters():
        #     if one.grad is not None:
        #         print(name, one.grad.mean())
        clip_grad_norm_(model.parameters(), 1)
        opt.step()
        # print(model._cs_pos[0].weight.mean(), 'stepped')
    lr = lr_scheduler.get_last_lr()[0]
    if print_flag:
        print('epoch %d (outer %d inner %d [%d:%d]) iter %d/%d: lr %.6f'
              % (lr_scheduler.last_epoch, eo, e, second, third, i, len(dl), float(lr)))
        print('\t', end='')
        for k in logs:
            print('.%s: %.6f(%.6f), ' % (k, logs[k], logh[k]), end='')
        print(' loss %.6f' % float(loss))
        for j in range(len(losses)):
            print('\tstep %d;' % j, end='')
def stage1_train(args):
    logger = init_logger(args)
    if args.summary:
        summary_writer = SummaryWriter(args.s1_summary_path)

    dataset = Birds(args.data_dir, split='train', im_size=64)
    dataloader = DataLoader(dataset, batch_size=args.s1_batch_size, shuffle=True,
                            num_workers=8, drop_last=True)

    generator = Stage1Generator(args.txt_embedding_dim, args.c_dim, args.z_dim,
                                args.gf_dim).cuda()
    print('generator={}'.format(generator))
    discriminator = Stage1Discriminator(args.df_dim, args.c_dim).cuda()
    print('discriminator={}'.format(discriminator))

    device_ids = list(range(torch.cuda.device_count()))
    generator = DataParallel(generator, device_ids)
    discriminator = DataParallel(discriminator, device_ids)

    g_parameters = list(filter(lambda f: f.requires_grad, generator.parameters()))
    d_parameters = list(filter(lambda f: f.requires_grad, discriminator.parameters()))
    g_optimizer = torch.optim.Adam(g_parameters, args.lr, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(d_parameters, args.lr, betas=(0.5, 0.999))

    r_labels = torch.ones((args.s1_batch_size,), device='cuda:0')
    f_labels = torch.zeros((args.s1_batch_size,), device='cuda:0')
    criterion = nn.BCELoss()

    cur_lr = args.lr
    for epoch in range(args.total_epoch):
        for idx, (r_imgs, txt_embeddings) in enumerate(dataloader):
            r_imgs = r_imgs.cuda()
            txt_embeddings = txt_embeddings.cuda()

            # discriminator step
            noise = torch.zeros((args.s1_batch_size, args.z_dim), device='cuda:0').normal_()
            x, mu, logvar = generator(txt_embeddings, noise)
            d_loss, r_loss, w_loss, f_loss = discriminator_loss(
                discriminator, r_imgs, x.detach(), mu.detach(),
                r_labels, f_labels, criterion)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # generator step
            noise = torch.zeros((args.s1_batch_size, args.z_dim), device='cuda:0').normal_()
            x, mu, logvar = generator(txt_embeddings, noise)
            logits = discriminator(mu.detach(), x)
            g_loss = criterion(logits, r_labels)
            kl_loss_ = kl_loss(mu, logvar)
            g_loss += kl_loss_
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if args.summary and idx % args.summary_iters == 0 and idx > 0:
                summary_writer.add_scalar('d_loss', d_loss.item())
                summary_writer.add_scalar('r_loss', r_loss.item())
                summary_writer.add_scalar('w_loss', w_loss.item())
                summary_writer.add_scalar('f_loss', f_loss.item())
                summary_writer.add_scalar('g_loss', g_loss.item())
                summary_writer.add_scalar('kl_loss', kl_loss_.item())

        if epoch % args.lr_decay_every_epoch == 0 and epoch > 0:
            logger.info(f'lr decay: {cur_lr}')
            cur_lr *= args.lr_decay_ratio
            g_optimizer = torch.optim.Adam(g_parameters, cur_lr, betas=(0.5, 0.999))
            d_optimizer = torch.optim.Adam(d_parameters, cur_lr, betas=(0.5, 0.999))

        if epoch % args.display_epoch == 0 and epoch > 0:
            logger.info(f'epoch:{epoch}, lr={cur_lr}, d_loss={d_loss}, r_loss={r_loss}, '
                        f'w_loss={w_loss}, f_loss={f_loss}, g_loss={g_loss}, kl_loss={kl_loss_}')

        if epoch % args.checkpoint_epoch == 0 and epoch > 0:
            if not os.path.isdir(args.s1_checkpoint_dir):
                os.makedirs(args.s1_checkpoint_dir)
            logger.info(f'saving checkpoints_{epoch}')
            torch.save(generator.state_dict(),
                       os.path.join(args.s1_checkpoint_dir, f'generator_epoch_{epoch}.pth'))
            torch.save(discriminator.state_dict(),
                       os.path.join(args.s1_checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))

    torch.save(generator.state_dict(),
               os.path.join(args.s1_checkpoint_dir, 'generator.pth'))
    torch.save(discriminator.state_dict(),
               os.path.join(args.s1_checkpoint_dir, 'discriminator.pth'))
    if args.summary:
        summary_writer.close()
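# kl_loss is called in stage1_train but defined elsewhere; a minimal sketch of
# the standard KL term used for StackGAN-style conditioning augmentation,
# KL(N(mu, exp(logvar)) || N(0, I)) (the repo's actual helper may differ):
import torch

def kl_loss(mu, logvar):
    # sum over the latent dimension, average over the batch
    return torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))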
def neuralwarp_train(**kwargs):  # multi-scale input training, 396+
    print(kwargs)
    # print("Mask == 1")
    with open(kwargs['params']) as f:
        params = json.load(f)
    if kwargs['manner'] == 'train':
        params['is_train'] = True
    else:
        params['is_train'] = False
    params['batch_size'] = kwargs['batch_size']

    if torch.cuda.device_count() > 1:
        print("-------------------Parallel_GPU_Train--------------------------")
        parallel = True
    else:
        print("------------------Single_GPU_Train----------------------")
        parallel = False

    opt.feature = 'cqt'
    opt.notes = 'SoftDTW'
    opt.model = 'SoftDTW'
    opt.batch_size = kwargs['batch_size']
    os.environ["CUDA_VISIBLE_DEVICES"] = str(kwargs["Device"])
    opt.Device = kwargs["Device"]
    # device_ids = [2]
    opt._parse(kwargs)
    model = getattr(models, opt.model)(params)
    p = 'check_points/' + model.model_name + opt.notes
    # f = os.path.join(p, "0620_07:05:30.pth")  # best Neural_dtw so far, cover80 map:0.705113267654046 0.08125 7.96875
    # f = os.path.join(p, "0620_17:37:35.pth")
    # f = os.path.join(p, "0621_22:42:59.pth")  # NeuralDTW_Milti_Metix_res 0622_16:33:07.pth 0621_22:42:59.pth
    # f = os.path.join(p, "0628_17:00:52.pth")  # 0628_17:00:52.pth FCN
    # f = os.path.join(p, "0623_16:01:05.pth")  # 3seq
    # f = os.path.join(p, "0630_07:59:56.pth")  # VGG11 0630_01:10:15.pth 0630_07:59:56.pth
    if kwargs['model'] == 'NeuralDTW_CNN_Mask_dilation_SPP':
        f = os.path.join(p, "0704_19:58:25.pth")
    elif kwargs['model'] == 'NeuralDTW_CNN_Mask_dilation_SPP2':
        f = os.path.join(p, "0709_00:31:23.pth")
    elif kwargs['model'] == 'NeuralDTW_CNN_Mask_dilation':
        f = os.path.join(p, "0704_06:40:41.pth")
    opt.load_model_path = f

    if kwargs['model'] != 'NeuralDTW' and kwargs['manner'] != 'train':
        if opt.load_latest is True:
            model.load_latest(opt.notes)
        elif opt.load_model_path:
            print("load_model:", opt.load_model_path)
            model.load(opt.load_model_path)
    if parallel:
        model = DataParallel(model)
    model.to(opt.device)
    torch.multiprocessing.set_sharing_strategy('file_system')

    # step 2: data
    out_length = 400
    if kwargs['model'] == 'NeuralDTW_CNN_Mask_300':
        out_length = 300
    if kwargs['model'] == 'NeuralDTW_CNN_Mask_spp':
        train_data0 = triplet_CQT(out_length=200, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data1 = triplet_CQT(out_length=300, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data2 = triplet_CQT(out_length=400, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
    else:
        train_data0 = triplet_CQT(out_length=out_length, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data1 = triplet_CQT(out_length=out_length, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
        train_data2 = triplet_CQT(out_length=out_length, is_label=kwargs['is_label'], is_random=kwargs['is_random'])
    val_data80 = CQT('songs80', out_length=kwargs['test_length'])
    val_data = CQT('songs350', out_length=kwargs['test_length'])
    val_data_marukars = CQT('Mazurkas', out_length=kwargs['test_length'])
    train_dataloader0 = DataLoader(train_data0, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    train_dataloader1 = DataLoader(train_data1, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    train_dataloader2 = DataLoader(train_data2, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    val_dataloader80 = DataLoader(val_data80, 1, shuffle=False, num_workers=1)
    val_dataloader = DataLoader(val_data, 1, shuffle=False, num_workers=1)
    val_dataloader_marukars = DataLoader(val_data_marukars, 1, shuffle=False, num_workers=1)

    if kwargs['manner'] == 'test':
        # val_slow(model, val_dataloader, style='null')
        val_slow_batch(model, val_dataloader_marukars, batch=100, is_dis=kwargs['zo'])
    elif kwargs['manner'] == 'visualize':
        visualize(model, val_dataloader80)
    elif kwargs['manner'] == 'mul_test':
        p = 'check_points/' + model.model_name + opt.notes
        l = sorted(os.listdir(p))[:20]
        best_MAP, MAP = 0, 0
        for f in l:
            f = os.path.join(p, f)
            model.load(f)
            model.to(opt.device)
            MAP += val_slow_batch(model, val_dataloader, batch=400, is_dis=kwargs['zo'])
            MAP += val_slow_batch(model, val_dataloader80, batch=400, is_dis=kwargs['zo'])
            if MAP > best_MAP:
                print('--best result--')
                best_MAP = MAP
            MAP = 0
    else:
        # step 3: criterion and optimizer
        be = torch.nn.BCELoss()
        lr = opt.lr
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
        # if parallel is True:
        #     optimizer = torch.optim.Adam(model.module.parameters(), lr=lr, weight_decay=opt.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.8, patience=10, verbose=True, min_lr=5e-6)

        # step 4: train
        best_MAP = 0
        for epoch in range(opt.max_epoch):
            running_loss = 0
            num = 0
            for ii, ((a0, p0, n0, la0, lp0, ln0),
                     (a1, p1, n1, la1, lp1, ln1),
                     (a2, p2, n2, la2, lp2, ln2)) in tqdm(
                    enumerate(zip(train_dataloader0, train_dataloader1, train_dataloader2))):
                # for ii, (a2, p2, n2) in tqdm(enumerate(train_dataloader2)):
                for flag in range(3):
                    if flag == 0:
                        a, p, n, la, lp, ln = a0, p0, n0, la0, lp0, ln0
                    elif flag == 1:
                        a, p, n, la, lp, ln = a1, p1, n1, la1, lp1, ln1
                    else:
                        a, p, n, la, lp, ln = a2, p2, n2, la2, lp2, ln2
                    B, _, _, _ = a.shape
                    if kwargs["zo"]:
                        target = torch.cat((torch.zeros(B), torch.ones(B))).cuda()
                    else:
                        target = torch.cat((torch.ones(B), torch.zeros(B))).cuda()
                    # train the model on this scale
                    a = a.requires_grad_().to(opt.device)
                    p = p.requires_grad_().to(opt.device)
                    n = n.requires_grad_().to(opt.device)
                    optimizer.zero_grad()
                    pred = model(a, p, n)
                    pred = pred.squeeze(1)
                    loss = be(pred, target)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    num += a.shape[0]
                if ii % 5000 == 0:
                    running_loss /= num
                    print("train_loss:", running_loss)
                    MAP = 0
                    print("Youtube350:")
                    MAP += val_slow_batch(model, val_dataloader, batch=1, is_dis=kwargs['zo'])
                    print("CoverSong80:")
                    MAP += val_slow_batch(model, val_dataloader80, batch=1, is_dis=kwargs['zo'])
                    # print("Marukars:")
                    # MAP += val_slow_batch(model, val_dataloader_marukars, batch=100, is_dis=kwargs['zo'])
                    if MAP > best_MAP:
                        best_MAP = MAP
                        print('*****************BEST*****************')
                        if kwargs['save_model']:
                            if parallel:
                                model.module.save(opt.notes)
                            else:
                                model.save(opt.notes)
                    scheduler.step(running_loss)
                    running_loss = 0
                    num = 0