class Trainer(object): """ Trainer encapsulates all the logic necessary for training the Recurrent Attention Model. All hyperparameters are provided by the user in the config file. """ def __init__(self, config, data_loader): """ Construct a new Trainer instance. Args ---- - config: object containing command line arguments. - data_loader: data iterator """ self.config = config # glimpse network params self.patch_size = config.patch_size # core network params self.num_glimpses = config.num_glimpses self.hidden_size = config.hidden_size # reinforce params self.std = config.std self.M = config.M # data params if config.is_train: self.train_loader = data_loader[0] self.valid_loader = data_loader[1] image_tmp, _ = iter(self.train_loader).next() self.image_size = (image_tmp.shape[2], image_tmp.shape[3]) if 'MNIST' in config.dataset_name or config.dataset_name == 'CIFAR': self.num_train = len(self.train_loader.sampler.indices) self.num_valid = len(self.valid_loader.sampler.indices) elif config.dataset_name == 'ImageNet': # the ImageNet cannot be sampled, otherwise this part will be wrong. self.num_train = 100000 #len(train_dataset) in data_loader.py, wrong: len(self.train_loader) self.num_valid = 10000 #len(self.valid_loader) else: self.test_loader = data_loader self.num_test = len(self.test_loader.dataset) image_tmp, _ = iter(self.test_loader).next() self.image_size = (image_tmp.shape[2], image_tmp.shape[3]) # assign numer of channels and classes of images in this dataset, maybe there is more robust way if 'MNIST' in config.dataset_name: self.num_channels = 1 self.num_classes = 10 elif config.dataset_name == 'ImageNet': self.num_channels = 3 self.num_classes = 1000 elif config.dataset_name == 'CIFAR': self.num_channels = 3 self.num_classes = 10 # training params self.epochs = config.epochs self.start_epoch = 0 self.momentum = config.momentum self.lr = config.init_lr self.loss_fun_baseline = config.loss_fun_baseline self.loss_fun_action = config.loss_fun_action self.weight_decay = config.weight_decay # misc params self.use_gpu = config.use_gpu self.best = config.best self.ckpt_dir = config.ckpt_dir self.logs_dir = config.logs_dir self.best_valid_acc = 0. self.best_train_acc = 0. 
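# Note on the shape probe above: `iter(loader).next()` only exists on older
# DataLoader iterators; the builtin `next(iter(loader))` works on every
# PyTorch / Python 3 version. The per-dataset channel/class branching can also
# be collapsed into a lookup table. A minimal, self-contained sketch (the tiny
# random dataset below is a stand-in for illustration, not the project's
# data_loader):
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(8, 1, 28, 28), torch.zeros(8)),
                    batch_size=4)

images, _ = next(iter(loader))                     # version-agnostic shape probe
image_size = (images.shape[2], images.shape[3])    # (H, W)

# one table instead of chained if/elif for per-dataset metadata
DATASET_META = {'MNIST': (1, 10), 'CIFAR': (3, 10), 'ImageNet': (3, 1000)}
num_channels, num_classes = DATASET_META['MNIST']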
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq

        if config.use_gpu:
            self.model_name = 'ram_gpu_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}'.format(
                config.PBSarray_ID, config.num_glimpses, config.patch_size,
                config.patch_size, config.hidden_size, config.std,
                config.dropout)
        else:
            self.model_name = 'ram_{0}_{1}_{2}x{3}_{4}_{5:1.2f}_{6}'.format(
                config.PBSarray_ID, config.num_glimpses, config.patch_size,
                config.patch_size, config.hidden_size, config.std,
                config.dropout)

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)
            writer = SummaryWriter(log_dir=tensorboard_dir)

        # build DRAMBUTD model
        self.model = RecurrentAttention(self.patch_size, self.num_channels,
                                        self.image_size, self.std,
                                        self.hidden_size, self.num_classes,
                                        config)
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        if config.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay)
        elif config.optimizer == 'ReduceLROnPlateau':
            # ReduceLROnPlateau is a scheduler, not an optimizer: it wraps an
            # already-constructed self.optimizer and takes no weight_decay.
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               'min',
                                               patience=self.lr_patience)
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(self.model.parameters(),
                                            weight_decay=self.weight_decay)
        elif config.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=3e-4,
                                        weight_decay=self.weight_decay)
        elif config.optimizer == 'AdaBound':
            self.optimizer = adabound.AdaBound(self.model.parameters(),
                                               lr=3e-4,
                                               final_lr=0.1,
                                               weight_decay=self.weight_decay)
        elif config.optimizer == 'Ranger':
            self.optimizer = Ranger(self.model.parameters(),
                                    weight_decay=self.weight_decay)

    def reset(self, x, SM):
        """
        Initialize the hidden state of the core network and the location
        vector. This is called once every time a new minibatch `x` is
        introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)
        h_t2, l_t, SM_local_smooth = self.model.initialize(x, SM)
        # initialize hidden state 1 as a zero vector so the model cannot
        # classify directly from the context
        h_t1 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        cell_state1 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        cell_state2 = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        return h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch, and if the
        validation accuracy is improved upon, a separate ckpt is created for
        use on the test set.
""" # load the most recent checkpoint if self.resume: self.load_checkpoint(best=False) print("\n[*] Train on {} samples, validate on {} samples".format( self.num_train, self.num_valid)) for epoch in range(self.start_epoch, self.epochs): print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs, self.lr)) # train for 1 epoch train_loss, train_acc = self.train_one_epoch(epoch) # evaluate on validation set valid_loss, valid_acc = self.validate(epoch) # # reduce lr if validation loss plateaus # self.scheduler.step(valid_loss) is_best_valid = valid_acc > self.best_valid_acc is_best_train = train_acc > self.best_train_acc msg1 = "train loss: {:.3f} - train acc: {:.3f} " msg2 = "- val loss: {:.3f} - val acc: {:.3f}" if is_best_train: msg1 += " [*]" if is_best_valid: self.counter = 0 msg2 += " [*]" msg = msg1 + msg2 print(msg.format(train_loss, train_acc, valid_loss, valid_acc)) # check for improvement if not is_best_valid: self.counter += 1 if self.counter > self.train_patience: print("[!] No improvement in a while, stopping training.") return self.best_valid_acc = max(valid_acc, self.best_valid_acc) self.best_train_acc = max(train_acc, self.best_train_acc) self.save_checkpoint( { 'epoch': epoch + 1, 'model_state': self.model.state_dict(), 'optim_state': self.optimizer.state_dict(), 'best_valid_acc': self.best_valid_acc, 'best_train_acc': self.best_train_acc, }, is_best_valid) def train_one_epoch(self, epoch): """ Train the model for 1 epoch of the training set. An epoch corresponds to one full pass through the entire training set in successive mini-batches. This is used by train() and should not be called manually. """ batch_time = AverageMeter() losses = AverageMeter() accs = AverageMeter() tic = time.time() with tqdm(total=self.num_train) as pbar: for i, (x_raw, y) in enumerate(self.train_loader): # if self.use_gpu: x_raw, y = x_raw.cuda(), y.cuda() # detach images and their saliency maps x = x_raw[:, 0, ...].unsqueeze(1) SM = x_raw[:, 1, ...].unsqueeze(1) plot = False if (epoch % self.plot_freq == 0) and (i == 0): plot = True # initialize location vector and hidden state self.batch_size = x.shape[0] h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset( x, SM) # save images imgs = [] imgs.append(x[0:9]) # extract the glimpses locs = [] log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model( x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) # last iteration h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model( x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth, last=True) log_pi.append(p) baselines.append(b_t) locs.append(l_t[0:9]) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # calculate reward predicted = torch.max(log_probas, 1)[1] if self.loss_fun_baseline == 'cross_entropy': # cross_entroy_loss need a long, batch x 1 tensor as target but R # also need to be subtracted by the baseline whose size is N x num_glimpse R = (predicted.detach() == y).long() # compute losses for differentiable modules loss_action, loss_baseline = self.choose_loss_fun( log_probas, y, baselines, R) R = R.float().unsqueeze(1).repeat(1, self.num_glimpses) else: R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # 
compute losses for differentiable modules loss_action, loss_baseline = self.choose_loss_fun( log_probas, y, baselines, R) # loss_action = F.nll_loss(log_probas, y) # loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss # summed over timesteps and averaged across batch adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store #losses.update(loss.data[0], x.size()[0]) #accs.update(acc.data[0], x.size()[0]) losses.update(loss.data.item(), x.size()[0]) accs.update(acc.data.item(), x.size()[0]) # compute gradients and update SGD self.optimizer.zero_grad() loss.backward() self.optimizer.step() # measure elapsed time toc = time.time() batch_time.update(toc - tic) pbar.set_description( ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format( (toc - tic), loss.data.item(), acc.data.item()))) pbar.update(self.batch_size) # dump the glimpses and locs if plot: if self.use_gpu: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] else: imgs = [g.data.numpy().squeeze() for g in imgs] locs = [l.data.numpy() for l in locs] pickle.dump( imgs, open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb")) pickle.dump( locs, open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb")) sio.savemat(self.plot_dir + "data_train_{}.mat".format(epoch + 1), mdict={ 'location': locs, 'patch': imgs }) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.train_loader) + i writer.add_scalar('Loss/train', losses, iteration) writer.add_scalar('Accuracy/train', accs, iteration) return losses.avg, accs.avg def validate(self, epoch): """ Evaluate the model on the validation set. 
""" losses = AverageMeter() accs = AverageMeter() for i, (x_raw, y) in enumerate(self.valid_loader): if self.use_gpu: x_raw, y = x_raw.cuda(), y.cuda() # detach images and their saliency maps x = x_raw[:, 0, ...].unsqueeze(1) SM = x_raw[:, 1, ...].unsqueeze(1) # duplicate M times x = x.repeat(self.M, 1, 1, 1) SM = SM.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset( x, SM) # extract the glimpses log_pi = [] baselines = [] for t in range(self.num_glimpses - 1): # forward pass through model h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model( x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth) # store baselines.append(b_t) log_pi.append(p) # last iteration h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model( x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth, last=True) # store log_pi.append(p) baselines.append(b_t) # convert list to tensors and reshape baselines = torch.stack(baselines).transpose(1, 0) log_pi = torch.stack(log_pi).transpose(1, 0) # average log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) baselines = baselines.contiguous().view(self.M, -1, baselines.shape[-1]) baselines = torch.mean(baselines, dim=0) log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1]) log_pi = torch.mean(log_pi, dim=0) # calculate reward predicted = torch.max(log_probas, 1)[1] if self.loss_fun_baseline == 'cross_entropy': # cross_entroy_loss need a long, batch x 1 tensor as target but R # also need to be subtracted by the baseline whose size is N x num_glimpse R = (predicted.detach() == y).long() # compute losses for differentiable modules loss_action, loss_baseline = self.choose_loss_fun( log_probas, y, baselines, R) R = R.float().unsqueeze(1).repeat(1, self.num_glimpses) else: R = (predicted.detach() == y).float() R = R.unsqueeze(1).repeat(1, self.num_glimpses) # compute losses for differentiable modules loss_action, loss_baseline = self.choose_loss_fun( log_probas, y, baselines, R) # compute losses for differentiable modules # loss_action = F.nll_loss(log_probas, y) # loss_baseline = F.mse_loss(baselines, R) # compute reinforce loss adjusted_reward = R - baselines.detach() loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1) loss_reinforce = torch.mean(loss_reinforce, dim=0) # sum up into a hybrid loss loss = loss_action + loss_baseline + loss_reinforce # compute accuracy correct = (predicted == y).float() acc = 100 * (correct.sum() / len(y)) # store losses.update(loss.data.item(), x.size()[0]) accs.update(acc.data.item(), x.size()[0]) # log to tensorboard if self.use_tensorboard: iteration = epoch * len(self.valid_loader) + i writer.add_scalar('Accuracy/valid', accs, iteration) writer.add_scalar('Loss/valid', losses, iteration) return losses.avg, accs.avg def choose_loss_fun(self, log_probas, y, baselines, R): """ use disctionary to save function handle replacement of swith-case be careful of the argument data type and shape!!! """ loss_fun_pool = { 'mse': F.mse_loss, 'l1': F.l1_loss, 'nll': F.nll_loss, 'smooth_l1': F.smooth_l1_loss, 'kl_div': F.kl_div, 'cross_entropy': F.cross_entropy } return loss_fun_pool[self.loss_fun_action]( log_probas, y), loss_fun_pool[self.loss_fun_baseline](baselines, R) def test(self): """ Test the model on the held-out test data. 
This function should only be called at the very end once the model has finished training. """ correct = 0 # load the best checkpoint self.load_checkpoint(best=self.best) for i, (x, y) in enumerate(self.test_loader): if self.use_gpu: x, y = x.cuda(), y.cuda() x, y = Variable(x, volatile=True), Variable(y) # duplicate 10 times x = x.repeat(self.M, 1, 1, 1) # initialize location vector and hidden state self.batch_size = x.shape[0] h_t1, h_t2, l_t, cell_state1, cell_state2, SM_local_smooth = self.reset( x, SM) # save images and glimpse location locs = [] imgs = [] imgs.append(x[0:9]) for t in range(self.num_glimpses - 1): # forward pass through model h_t1, h_t2, l_t, b_t, p, cell_state1, cell_state2, SM_local_smooth = self.model( x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth) # store locs.append(l_t[0:9]) baselines.append(b_t) log_pi.append(p) # last iteration h_t1, h_t2, l_t, b_t, log_probas, p, cell_state1, cell_state2, SM_local_smooth = self.model( x, l_t, h_t1, h_t2, cell_state1, cell_state2, SM, SM_local_smooth, last=True) log_probas = log_probas.view(self.M, -1, log_probas.shape[-1]) log_probas = torch.mean(log_probas, dim=0) pred = log_probas.data.max(1, keepdim=True)[1] correct += pred.eq(y.data.view_as(pred)).cpu().sum() # dump test data if self.use_gpu: imgs = [g.cpu().data.numpy().squeeze() for g in imgs] locs = [l.cpu().data.numpy() for l in locs] else: imgs = [g.data.numpy().squeeze() for g in imgs] locs = [l.data.numpy() for l in locs] pickle.dump(imgs, open(self.plot_dir + "g_test.p", "wb")) pickle.dump(locs, open(self.plot_dir + "l_test.p", "wb")) sio.savemat(self.plot_dir + "test_transient.mat", mdict={'location': locs}) perc = (100. * correct) / (self.num_test) error = 100 - perc print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format( correct, self.num_test, perc, error)) def save_checkpoint(self, state, is_best): """ Save a copy of the model so that it can be loaded at a future date. This function is used when the model is being evaluated on the test data. If this model has reached the best validation accuracy thus far, a seperate file with the suffix `best` is created. """ # print("[*] Saving model to {}".format(self.ckpt_dir)) filename = self.model_name + '_ckpt.pth.tar' ckpt_path = os.path.join(self.ckpt_dir, filename) torch.save(state, ckpt_path) if is_best: filename = self.model_name + '_model_best.pth.tar' shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename)) def load_checkpoint(self, best=False): """ Load the best copy of a model. This is useful for 2 cases: - Resuming training with the most recent model checkpoint. - Loading the best validation model to evaluate on the test data. Params ------ - best: if set to True, loads the best model. Use this if you want to evaluate your model on the test data. Else, set to False in which case the most recent version of the checkpoint is used. 
""" print("[*] Loading model from {}".format(self.ckpt_dir)) filename = self.model_name + '_ckpt.pth.tar' if best: filename = self.model_name + '_model_best.pth.tar' ckpt_path = os.path.join(self.ckpt_dir, filename) ckpt = torch.load(ckpt_path) # load variables from checkpoint self.start_epoch = ckpt['epoch'] self.best_valid_acc = ckpt['best_valid_acc'] self.model.load_state_dict(ckpt['model_state']) self.optimizer.load_state_dict(ckpt['optim_state']) if best: print("[*] Loaded {} checkpoint @ epoch {} " "with best valid acc of {:.3f}".format( filename, ckpt['epoch'], ckpt['best_valid_acc'])) else: print("[*] Loaded {} checkpoint @ epoch {}".format( filename, ckpt['epoch']))
        plt.plot(x, lossx, label='loss')
        # plt.plot(x, rmsex, label='rmse')
        plt.legend()
        # changepoint: save the loss curve as an image, since checking TensorBoard is too cumbersome
        plt.savefig(
            '/media/workdir/hujh/hujh-new/huaweirader_baseline/log/demolog/predrnnloss.png'
        )
        plt.close(1)
        #################################################################################
        # changepoint
        if ind % 100 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                save_dir=save_dir,
                filename='predrnncheckpoint.pth.tar')

        ################# valid ########################################################
        # if ind % 1000 == 0 and ind > 0:
        val_compareloss = []
        hss = []
        # model.eval()
        if False:
            with torch.no_grad():
                val_rmse = AverageMeter()
                val_losses = AverageMeter()
                if False:
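# The fragment above calls a save_checkpoint helper that is defined elsewhere
# in that project. A minimal stand-in consistent with how it is invoked here
# (keyword arguments save_dir and filename) might look like the following; the
# body is an assumption, not the project's actual helper:
import os
import torch

def save_checkpoint(state, save_dir, filename='checkpoint.pth.tar'):
    """Serialize a training-state dict (epoch, model and optimizer states) to save_dir/filename."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))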
def main(args): """ The main training function. Only works for single node (be it single or multi-GPU) Parameters ---------- args : Parsed arguments """ # setup ngpus = torch.cuda.device_count() if ngpus == 0: raise RuntimeWarning("This will not be able to run on CPU only") print(f"Working with {ngpus} GPUs") if args.optim.lower() == "ranger": # No warm up if ranger optimizer args.warm = 0 current_experiment_time = datetime.now().strftime('%Y%m%d_%T').replace(":", "") args.exp_name = f"{'debug_' if args.debug else ''}{current_experiment_time}_" \ f"_fold{args.fold if not args.full else 'FULL'}" \ f"_{args.arch}_{args.width}" \ f"_batch{args.batch_size}" \ f"_optim{args.optim}" \ f"_{args.optim}" \ f"_lr{args.lr}-wd{args.weight_decay}_epochs{args.epochs}_deepsup{args.deep_sup}" \ f"_{'fp16' if not args.no_fp16 else 'fp32'}" \ f"_warm{args.warm}_" \ f"_norm{args.norm_layer}{'_swa' + str(args.swa_repeat) if args.swa else ''}" \ f"_dropout{args.dropout}" \ f"_warm_restart{args.warm_restart}" \ f"{'_' + args.com.replace(' ', '_') if args.com else ''}" args.save_folder = pathlib.Path(f"./runs/{args.exp_name}") args.save_folder.mkdir(parents=True, exist_ok=True) args.seg_folder = args.save_folder / "segs" args.seg_folder.mkdir(parents=True, exist_ok=True) args.save_folder = args.save_folder.resolve() save_args(args) t_writer = SummaryWriter(str(args.save_folder)) # Create model print(f"Creating {args.arch}") model_maker = getattr(models, args.arch) model = model_maker( 4, 3, width=args.width, deep_supervision=args.deep_sup, norm_layer=get_norm_layer(args.norm_layer), dropout=args.dropout) print(f"total number of trainable parameters {count_parameters(model)}") if args.swa: # Create the average model swa_model = model_maker( 4, 3, width=args.width, deep_supervision=args.deep_sup, norm_layer=get_norm_layer(args.norm_layer)) for param in swa_model.parameters(): param.detach_() swa_model = swa_model.cuda() swa_model_optim = WeightSWA(swa_model) if ngpus > 1: model = torch.nn.DataParallel(model).cuda() else: model = model.cuda() print(model) model_file = args.save_folder / "model.txt" with model_file.open("w") as f: print(model, file=f) criterion = EDiceLoss().cuda() metric = criterion.metric print(metric) rangered = False # needed because LR scheduling scheme is different for this optimizer if args.optim == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=1e-4) elif args.optim == "sgd": optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9, nesterov=True) elif args.optim == "adamw": print(f"weight decay argument will not be used. 
Default is 11e-2") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) elif args.optim == "ranger": optimizer = Ranger(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) rangered = True # optionally resume from a checkpoint if args.resume: reload_ckpt(args, model, optimizer) if args.debug: args.epochs = 2 args.warm = 0 args.val = 1 if args.full: train_dataset, bench_dataset = get_datasets(args.seed, args.debug, full=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False, drop_last=True) bench_loader = torch.utils.data.DataLoader( bench_dataset, batch_size=1, num_workers=args.workers) else: train_dataset, val_dataset, bench_dataset = get_datasets(args.seed, args.debug, fold_number=args.fold) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=False, drop_last=True) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=max(1, args.batch_size // 2), shuffle=False, pin_memory=False, num_workers=args.workers, collate_fn=determinist_collate) bench_loader = torch.utils.data.DataLoader( bench_dataset, batch_size=1, num_workers=args.workers) print("Val dataset number of batch:", len(val_loader)) print("Train dataset number of batch:", len(train_loader)) # create grad scaler scaler = GradScaler() # Actual Train loop best = np.inf print("start warm-up now!") if args.warm != 0: tot_iter_train = len(train_loader) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda cur_iter: (1 + cur_iter) / (tot_iter_train * args.warm)) patients_perf = [] if not args.resume: for epoch in range(args.warm): ts = time.perf_counter() model.train() training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer, scaler, scheduler, save_folder=args.save_folder, no_fp16=args.no_fp16, patients_perf=patients_perf) te = time.perf_counter() print(f"Train Epoch done in {te - ts} s") # Validate at the end of epoch every val step if (epoch + 1) % args.val == 0 and not args.full: model.eval() with torch.no_grad(): validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16) t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch) if args.warm_restart: print('Total number of epochs should be divisible by 30, else it will do odd things') scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 30, eta_min=1e-7) else: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs + 30 if not rangered else round( args.epochs * 0.5)) print("start training now!") if args.swa: # c = 15, k=3, repeat = 5 c, k, repeat = 30, 3, args.swa_repeat epochs_done = args.epochs reboot_lr = 0 if args.debug: c, k, repeat = 2, 1, 2 for epoch in range(args.start_epoch + args.warm, args.epochs + args.warm): try: # do_epoch for one epoch ts = time.perf_counter() model.train() training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, epoch, t_writer, scaler, save_folder=args.save_folder, no_fp16=args.no_fp16, patients_perf=patients_perf) te = time.perf_counter() print(f"Train Epoch done in {te - ts} s") # Validate at the end of epoch every val step if (epoch + 1) % args.val == 0 and not args.full: model.eval() with torch.no_grad(): validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, 
epoch, t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16, patients_perf=patients_perf) t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, epoch) if validation_loss < best: best = validation_loss model_dict = model.state_dict() save_checkpoint( dict( epoch=epoch, arch=args.arch, state_dict=model_dict, optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict(), ), save_folder=args.save_folder, ) ts = time.perf_counter() print(f"Val epoch done in {ts - te} s") if args.swa: if (args.epochs - epoch - c) == 0: reboot_lr = optimizer.param_groups[0]['lr'] if not rangered: scheduler.step() print("scheduler stepped!") else: if epoch / args.epochs > 0.5: scheduler.step() print("scheduler stepped!") except KeyboardInterrupt: print("Stopping training loop, doing benchmark") break if args.swa: swa_model_optim.update(model) print("SWA Model initialised!") for i in range(repeat): optimizer = torch.optim.Adam(model.parameters(), args.lr / 2, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, c + 10) for swa_epoch in range(c): # do_epoch for one epoch ts = time.perf_counter() model.train() swa_model.train() current_epoch = epochs_done + i * c + swa_epoch training_loss = step(train_loader, model, criterion, metric, args.deep_sup, optimizer, current_epoch, t_writer, scaler, no_fp16=args.no_fp16, patients_perf=patients_perf) te = time.perf_counter() print(f"Train Epoch done in {te - ts} s") t_writer.add_scalar(f"SummaryLoss/train", training_loss, current_epoch) # update every k epochs and val: print(f"cycle number: {i}, swa_epoch: {swa_epoch}, total_cycle_to_do {repeat}") if (swa_epoch + 1) % k == 0: swa_model_optim.update(model) if not args.full: model.eval() swa_model.eval() with torch.no_grad(): validation_loss = step(val_loader, model, criterion, metric, args.deep_sup, optimizer, current_epoch, t_writer, save_folder=args.save_folder, no_fp16=args.no_fp16) swa_model_loss = step(val_loader, swa_model, criterion, metric, args.deep_sup, optimizer, current_epoch, t_writer, swa=True, save_folder=args.save_folder, no_fp16=args.no_fp16) t_writer.add_scalar(f"SummaryLoss/val", validation_loss, current_epoch) t_writer.add_scalar(f"SummaryLoss/swa", swa_model_loss, current_epoch) t_writer.add_scalar(f"SummaryLoss/overfit", validation_loss - training_loss, current_epoch) t_writer.add_scalar(f"SummaryLoss/overfit_swa", swa_model_loss - training_loss, current_epoch) scheduler.step() epochs_added = c * repeat save_checkpoint( dict( epoch=args.epochs + epochs_added, arch=args.arch, state_dict=swa_model.state_dict(), optimizer=optimizer.state_dict() ), save_folder=args.save_folder, ) else: save_checkpoint( dict( epoch=args.epochs, arch=args.arch, state_dict=model.state_dict(), optimizer=optimizer.state_dict() ), save_folder=args.save_folder, ) try: df_individual_perf = pd.DataFrame.from_records(patients_perf) print(df_individual_perf) df_individual_perf.to_csv(f'{str(args.save_folder)}/patients_indiv_perf.csv') reload_ckpt_bis(f'{str(args.save_folder)}/model_best.pth.tar', model) generate_segmentations(bench_loader, model, t_writer, args) except KeyboardInterrupt: print("Stopping right now!")
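# The warm-up phase above raises the learning rate linearly by stepping a
# LambdaLR once per iteration, with the multiplier reaching 1.0 after
# args.warm epochs' worth of batches. A self-contained sketch of that schedule
# with a throwaway model (the sizes and counts are placeholders):
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

iters_per_epoch, warm_epochs = 100, 3
warmup = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda cur_iter: (1 + cur_iter) / (iters_per_epoch * warm_epochs))

for step in range(iters_per_epoch * warm_epochs):
    optimizer.step()   # in the real loop this follows loss.backward()
    warmup.step()      # stepped every iteration during warm-up only

print(optimizer.param_groups[0]['lr'])   # back to roughly the base LR of 1e-3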
def main(args, logger):
    writer = SummaryWriter(args.subTensorboardDir)
    model = Vgg().to(device)
    trainSet = Lung(rootDir=args.dataDir, mode='train', size=args.inputSize)
    valSet = Lung(rootDir=args.dataDir, mode='test', size=args.inputSize)
    trainDataloader = DataLoader(trainSet,
                                 batch_size=args.batchSize,
                                 drop_last=True,
                                 shuffle=True,
                                 pin_memory=False,
                                 num_workers=args.numWorkers)
    valDataloader = DataLoader(valSet,
                               batch_size=args.valBatchSize,
                               drop_last=False,
                               shuffle=False,
                               pin_memory=False,
                               num_workers=args.numWorkers)
    criterion = nn.CrossEntropyLoss()
    optimizer = Ranger(model.parameters(), lr=args.lr)
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.apexType)

    iter = 0
    runningLoss = []
    for epoch in range(args.epoch):
        if epoch != 0 and epoch % args.evalFrequency == 0:
            f1, acc = eval(model, valDataloader, logger)
            writer.add_scalars('f1_acc', {'f1': f1, 'acc': acc}, iter)

        if epoch != 0 and epoch % args.saveFrequency == 0:
            modelName = osp.join(args.subModelDir, 'out_{}.pt'.format(epoch))
            # unwrap the DataParallel/DDP wrapper so saving does not fail under distributed training
            stateDict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            torch.save(stateDict, modelName)
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict()
            }
            torch.save(checkpoint, modelName)

        for img, lb, _ in trainDataloader:
            # array = np.array(img)
            # for i in range(array.shape[0]):
            #     plt.imshow(array[i, 0, ...], cmap='gray')
            #     plt.show()
            iter += 1
            img, lb = img.to(device), lb.to(device)
            optimizer.zero_grad()
            outputs = model(img)
            loss = criterion(outputs.squeeze(), lb.long())
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            # loss.backward()
            optimizer.step()
            runningLoss.append(loss.item())

            if iter % args.msgFrequency == 0:
                avgLoss = np.mean(runningLoss)
                runningLoss = []
                lr = optimizer.param_groups[0]['lr']
                logger.info(f'epoch: {epoch} / {args.epoch}, '
                            f'iter: {iter} / {len(trainDataloader) * args.epoch}, '
                            f'lr: {lr}, '
                            f'loss: {avgLoss:.4f}')
                writer.add_scalar('loss', avgLoss, iter)

    eval(model, valDataloader, logger)
    modelName = osp.join(args.subModelDir, 'final.pth')
    stateDict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    torch.save(stateDict, modelName)
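# The loop above relies on NVIDIA apex (amp.initialize / amp.scale_loss) for
# mixed precision. For comparison, the same train step written with PyTorch's
# built-in AMP utilities, sketched on a toy linear model rather than this
# script's Vgg:
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(16, 2).to(device)                   # toy stand-in
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

img = torch.randn(8, 16, device=device)
lb = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
    outputs = model(img)
    loss = criterion(outputs, lb)
scaler.scale(loss).backward()   # scaled backward, analogous to amp.scale_loss
scaler.step(optimizer)
scaler.update()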
def train_model(dataset=dataset, save_dir=save_dir, num_classes=num_classes, lr=lr, num_epochs=nEpochs, save_epoch=snapshot, useTest=useTest, test_interval=nTestInterval): """ Args: num_classes (int): Number of classes in the data num_epochs (int, optional): Number of epochs to train for. """ file = open('run/log.txt', 'w') if modelName == 'C3D': model = C3D(num_class=num_classes) model.my_load_pretrained_weights('saved_model/c3d.pickle') train_params = model.parameters() # train_params = [{'params': get_1x_lr_params(model), 'lr': lr}, # {'params': get_10x_lr_params(model), 'lr': lr * 10}] # elif modelName == 'R2Plus1D': # model = R2Plus1D_model.R2Plus1DClassifier(num_classes=num_classes, layer_sizes=(2, 2, 2, 2)) # train_params = [{'params': R2Plus1D_model.get_1x_lr_params(model), 'lr': lr}, # {'params': R2Plus1D_model.get_10x_lr_params(model), 'lr': lr * 10}] # elif modelName == 'R3D': # model = R3D_model.R3DClassifier(num_classes=num_classes, layer_sizes=(2, 2, 2, 2)) # train_params = model.parameters() elif modelName == 'Res3D': # model = Resnet(num_classes=num_classes, block=resblock, layers=[3, 4, 6, 3]) # train_params=model.parameters() model = generate_model(50) model = load_pretrained_model(model, './saved_model/r3d50_K_200ep.pth', n_finetune_classes=num_classes) train_params = model.parameters() else: print('We only implemented C3D and R2Plus1D models.') raise NotImplementedError criterion = nn.CrossEntropyLoss( ) # standard crossentropy loss for classification # optimizer = torch.optim.Adam(train_params, lr=lr, betas=(0.9, 0.999), weight_decay=1e-5, # amsgrad=True) optimizer = Ranger(train_params, lr=lr, betas=(.95, 0.999), weight_decay=5e-4) print('use ranger') scheduler = CosineAnnealingLR(optimizer, T_max=32, eta_min=0, last_epoch=-1) # optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4) # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, # gamma=0.1) # the scheduler divides the lr by 10 every 10 epochs if resume_epoch == 0: print("Training {} from scratch...".format(modelName)) else: checkpoint = torch.load(os.path.join( save_dir, 'models', saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar'), map_location=lambda storage, loc: storage ) # Load all tensors onto the CPU print("Initializing weights from: {}...".format( os.path.join( save_dir, 'models', saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar'))) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['opt_dict']) print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) # model.to(device) if torch.cuda.is_available(): model = model.cuda() torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True model = nn.DataParallel(model) criterion.cuda() # log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) log_dir = os.path.join(save_dir) writer = SummaryWriter(log_dir=log_dir) print('Training model on {} dataset...'.format(dataset)) train_dataloader = DataLoader(VideoDataset(dataset=dataset, split='train', clip_len=16), batch_size=8, shuffle=True, num_workers=8) val_dataloader = DataLoader(VideoDataset(dataset=dataset, split='validation', clip_len=16), batch_size=8, num_workers=8) test_dataloader = DataLoader(VideoDataset(dataset=dataset, split='test', clip_len=16), batch_size=8, num_workers=8) trainval_loaders = {'train': train_dataloader, 'val': val_dataloader} trainval_sizes = { x: len(trainval_loaders[x].dataset) for x in ['train', 'val'] } 
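# The resume branch above deserializes the checkpoint onto the CPU via
# map_location and then restores both the model and optimizer state dicts.
# The same round trip in isolation, with a placeholder linear model and file
# name:
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

torch.save({'epoch': 10,
            'state_dict': model.state_dict(),
            'opt_dict': optimizer.state_dict()}, 'example_ckpt.pth.tar')

checkpoint = torch.load('example_ckpt.pth.tar',
                        map_location=lambda storage, loc: storage)  # load onto CPU
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['opt_dict'])
resume_epoch = checkpoint['epoch']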
test_size = len(test_dataloader.dataset) # my_smooth={'0': 0.88, '1': 0.95, '2': 0.96, '3': 0.79, '4': 0.65, '5': 0.89, '6': 0.88} for epoch in range(resume_epoch, num_epochs): # each epoch has a training and validation step for phase in ['train', 'val']: start_time = timeit.default_timer() # reset the running loss and corrects running_loss = 0.0 running_corrects = 0.0 # set model to train() or eval() mode depending on whether it is trained # or being validated. Primarily affects layers such as BatchNorm or Dropout. if phase == 'train': # scheduler.step() is to be called once every epoch during training # scheduler.step() model.train() else: model.eval() for inputs, labels in tqdm(trainval_loaders[phase]): # move inputs and labels to the device the training is taking place on inputs = Variable(inputs, requires_grad=True).to(device) labels = Variable(labels).to(device) # inputs = inputs.cuda(non_blocking=True) # labels = labels.cuda(non_blocking=True) optimizer.zero_grad() if phase == 'train': outputs = model(inputs) else: with torch.no_grad(): outputs = model(inputs) probs = nn.Softmax(dim=1)(outputs) # the size of output is [bs , 7] preds = torch.max(probs, 1)[1] # preds is the index of maxnum of output # print(outputs) # print(torch.max(outputs, 1)) loss = criterion(outputs, labels) if phase == 'train': loss.backward() optimizer.step() scheduler.step(loss) # for name, parms in model.named_parameters(): # print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \ # ' -->grad_value:', parms.grad) # print('-->name:', name, ' -->grad_value:', parms.grad) running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) print('\ntemp/label:{}/{}'.format(preds[0], labels[0])) epoch_loss = running_loss / trainval_sizes[phase] epoch_acc = running_corrects.double() / trainval_sizes[phase] if phase == 'train': writer.add_scalar('data/train_loss_epoch', epoch_loss, epoch) writer.add_scalar('data/train_acc_epoch', epoch_acc, epoch) else: writer.add_scalar('data/val_loss_epoch', epoch_loss, epoch) writer.add_scalar('data/val_acc_epoch', epoch_acc, epoch) print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format( phase, epoch + 1, nEpochs, epoch_loss, epoch_acc)) stop_time = timeit.default_timer() print("Execution time: " + str(stop_time - start_time) + "\n") file.write("\n[{}] Epoch: {}/{} Loss: {} Acc: {}".format( phase, epoch + 1, nEpochs, epoch_loss, epoch_acc)) if epoch % save_epoch == (save_epoch - 1): torch.save( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'opt_dict': optimizer.state_dict(), }, os.path.join(save_dir, saveName + '_epoch-' + str(epoch) + '.pth.tar')) print("Save model at {}\n".format( os.path.join(save_dir, saveName + '_epoch-' + str(epoch) + '.pth.tar'))) if useTest and epoch % test_interval == (test_interval - 1): model.eval() start_time = timeit.default_timer() running_loss = 0.0 running_corrects = 0.0 for inputs, labels in tqdm(test_dataloader): inputs = inputs.to(device) labels = labels.to(device) with torch.no_grad(): outputs = model(inputs) probs = nn.Softmax(dim=1)(outputs) preds = torch.max(probs, 1)[1] loss = criterion(outputs, labels) running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / test_size epoch_acc = running_corrects.double() / test_size writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch) writer.add_scalar('data/test_acc_epoch', epoch_acc, epoch) print("[test] Epoch: {}/{} Loss: {} Acc: {}".format( epoch + 1, nEpochs, epoch_loss, 
epoch_acc)) stop_time = timeit.default_timer() print("Execution time: " + str(stop_time - start_time) + "\n") file.write("\n[test] Epoch: {}/{} Loss: {} Acc: {}\n".format( epoch + 1, nEpochs, epoch_loss, epoch_acc)) writer.close() file.close()
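# CosineAnnealingLR (used above) and ReduceLROnPlateau follow different
# stepping conventions: cosine annealing is stepped without a metric,
# typically once per epoch, while ReduceLROnPlateau is the scheduler that
# consumes a validation loss. A small sketch of both, with throwaway
# optimizers:
import torch

model = torch.nn.Linear(4, 2)

opt = torch.optim.SGD(model.parameters(), lr=0.1)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=32, eta_min=0)
for epoch in range(3):
    opt.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    opt.step()
    cosine.step()        # no argument, once per epoch

opt2 = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt2, mode='min', patience=5)
for epoch in range(3):
    val_loss = 1.0 / (epoch + 1)   # placeholder validation metric
    plateau.step(val_loss)         # metric-driven LR reduction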
def main(): global best_acc, mean, std, scale args = parse_args() args.mean, args.std, args.scale = mean, std, scale args.is_master = args.local_rank == 0 if args.deterministic: cudnn.deterministic = True torch.manual_seed(0) random.seed(0) np.random.seed(0) args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 if args.is_master: print("opt_level = {}".format(args.opt_level)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) print(f"Distributed Training Enabled: {args.distributed}") args.gpu = 0 args.world_size = 1 if args.distributed: args.gpu = args.local_rank torch.cuda.set_device(args.gpu) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() # Scale learning rate based on global batch size # args.lr *= args.batch_size * args.world_size / 256 assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." # create model model = models.ResNet18(args.num_patches, args.num_angles) if args.sync_bn: import apex print("using apex synced BN") model = apex.parallel.convert_syncbn_model(model) model = model.cuda() optimiser = Ranger(model.parameters(), lr=args.lr) criterion = nn.CrossEntropyLoss().cuda() # Initialize Amp. Amp accepts either values or strings for the optional override arguments, # for convenient interoperation with argparse. model, optimiser = amp.initialize( model, optimiser, opt_level=args.opt_level, keep_batchnorm_fp32=args.keep_batchnorm_fp32, loss_scale=args.loss_scale) # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks. 
if args.distributed: model = DDP(model, delay_allreduce=True) # Optionally resume from a checkpoint if args.resume: # Use a local scope to avoid dangling references def resume(): global best_acc if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load( args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu)) args.start_epoch = checkpoint['epoch'] best_acc = checkpoint['best_acc'] args.poisson_rate = checkpoint["poisson_rate"] model.load_state_dict(checkpoint['state_dict']) optimiser.load_state_dict(checkpoint['optimiser']) print("=> loaded checkpoint '{}' (epoch {})".format( args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) resume() # Data loading code train_dir = os.path.join(args.data, 'train') val_dir = os.path.join(args.data, 'val') crop_size = 225 val_size = 256 imagenet_train = datasets.ImageFolder( root=train_dir, transform=transforms.Compose([ transforms.RandomResizedCrop(crop_size), ])) train_dataset = SSLTrainDataset(imagenet_train, args.num_patches, args.num_angles, args.poisson_rate) imagenet_val = datasets.ImageFolder(root=val_dir, transform=transforms.Compose([ transforms.Resize(val_size), transforms.CenterCrop(crop_size), ])) val_dataset = SSLValDataset(imagenet_val, args.num_patches, args.num_angles) train_sampler = None val_sampler = None if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset) val_sampler = torch.utils.data.distributed.DistributedSampler( val_dataset) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=fast_collate) if args.evaluate: val_loss, val_acc = apex_validate(val_loader, model, criterion, args) utils.logger.info(f"Val Loss = {val_loss}, Val Accuracy = {val_acc}") return # Create dir to save model and command-line args if args.is_master: model_dir = time.ctime().replace(" ", "_").replace(":", "_") model_dir = os.path.join("models", model_dir) os.makedirs(model_dir, exist_ok=True) with open(os.path.join(model_dir, "args.json"), "w") as f: json.dump(args.__dict__, f, indent=2) writer = SummaryWriter() for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) # train for one epoch train_loss, train_acc = apex_train(train_loader, model, criterion, optimiser, args, epoch) # evaluate on validation set val_loss, val_acc = apex_validate(val_loader, model, criterion, args) if (epoch + 1) % args.learn_prd == 0: utils.adj_poisson_rate(train_loader, args) # remember best Acc and save checkpoint if args.is_master: is_best = val_acc > best_acc best_acc = max(val_acc, best_acc) save_checkpoint( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc': best_acc, 'optimiser': optimiser.state_dict(), "poisson_rate": args.poisson_rate }, is_best, model_dir) writer.add_scalars("Loss", { "train_loss": train_loss, "val_loss": val_loss }, epoch) writer.add_scalars("Accuracy", { "train_acc": train_acc, "val_acc": val_acc }, epoch) writer.add_scalar("Poisson_Rate", train_loader.dataset.pdist.rate, epoch)
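# The distributed branch above calls train_sampler.set_epoch(epoch) at the top
# of every epoch; that call reseeds the DistributedSampler's shuffle so each
# rank sees a new permutation per epoch. A small sketch that runs on a single
# process by fixing num_replicas and rank explicitly:
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float())   # tiny stand-in dataset
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(2):
    sampler.set_epoch(epoch)            # without this, every epoch repeats the same order
    print(epoch, list(iter(sampler)))   # the indices this rank would draw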