def train(network_specs, training_params, image_path, save_path, ckpt_path, epoch=10):
    """Create the data pipeline and the UNet model, then start training.

    Args:
        network_specs: Forwarded to ``UNet`` as ``network_specs``.
        training_params: Forwarded to both ``DataPipeline`` and ``UNet``.
        image_path: Forwarded to ``DataPipeline`` as ``image_path``.
        save_path: Forwarded to ``model.train`` as ``save_path``.
        ckpt_path: Forwarded to ``model.train`` as ``ckpt_path``.
        epoch: Forwarded to ``model.train`` as ``epoch`` (default 10).
    """
    # Step 1: the image data pipeline.
    print('creating datapipe...')
    pipeline = DataPipeline(
        image_path=image_path,
        training_params=training_params,
    )

    # Step 2: the UNet model, wired to the pipeline created above.
    print('creating network model...')
    unet = UNet(
        network_specs=network_specs,
        datapipe=pipeline,
        training_params=training_params,
    )

    # Step 3: hand control over to the model's own training loop.
    banner = ''' ============= HERE WE GO ============= '''
    print(banner)
    unet.train(save_path=save_path, ckpt_path=ckpt_path, epoch=epoch)
def lr_find(model: UNet, data_loader, optimizer: Optimizer, criterion, use_gpu, min_lr=0.0001, max_lr=0.1):
    """Run a learning-rate range test and plot the loss against the learning rate.

    The optimizer's learning rate is driven by a ``CyclicExpLR`` schedule
    between ``min_lr`` and ``max_lr`` while the model is trained on batches
    from ``data_loader``. The sweep stops as soon as the learning rate stops
    increasing (or the loader is exhausted). The recorded (lr, loss) pairs are
    plotted on a logarithmic x axis, and the model and optimizer are restored
    to the state they had on entry.

    Args:
        model: Network to probe; it is switched to training mode.
        data_loader: Iterable yielding ``(data, target, class_ids)`` batches.
        optimizer: Optimizer whose first parameter group's ``lr`` is tracked.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        use_gpu: When true, batches and the output buffer are moved to CUDA.
        min_lr: Lower bound handed to the scheduler.
        max_lr: Upper bound handed to the scheduler.
    """
    import copy

    # state_dict() returns references to the live parameter/buffer tensors, so
    # the snapshots must be deep-copied; otherwise the training steps below
    # would mutate them and the final "restore" would be a no-op.
    model_state = copy.deepcopy(model.state_dict())
    optimizer_state = copy.deepcopy(optimizer.state_dict())

    losses = []
    lrs = []

    try:
        scheduler = CyclicExpLR(optimizer, min_lr, max_lr, step_size_up=100, mode='triangular', cycle_momentum=True)

        model.train()
        for data, target, class_ids in data_loader:
            if use_gpu:
                data = data.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            output_raw = model(data)

            # Project-specific step: keep, for every sample, only the output
            # channel selected by its (1-based) class id.
            output = torch.zeros(output_raw.shape[0], 1, output_raw.shape[2], output_raw.shape[3])
            if use_gpu:
                output = output.cuda()
            for idx, (raw_o, class_id) in enumerate(zip(output_raw, class_ids)):
                output[idx] = raw_o[class_id - 1]

            loss = criterion(output, target)
            loss.backward()

            current_lr = optimizer.param_groups[0]['lr']
            # Stop once the schedule has passed its peak and lr starts to drop.
            if lrs and current_lr < lrs[-1]:
                break
            lrs.append(current_lr)
            losses.append(loss.item())

            optimizer.step()
            scheduler.step()

        # Plot in log scale.
        plt.plot(lrs, losses)
        plt.xscale('log')
        plt.show()
    finally:
        # Always revert, even if the sweep or the plotting raised.
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
def main():
    """Load the dataset, train a UNet on it and save the loss curves.

    Command-line options are read from the module-level ``parser``. The train
    and test splits are loaded with ``load_data``, the network is trained with
    ``UNet.train`` and the returned loss histories are plotted to ``Loss.png``
    with ``plot_loss``.
    """
    FLAGS = parser.parse_args()

    # Both splits are loaded with the same image size and augmentation flag.
    print("Loading dataset.")
    train_dir = FLAGS.dataset_dir + '/train'
    test_dir = FLAGS.dataset_dir + '/test'
    X_train, y_train = load_data(train_dir, FLAGS.img_size, FLAGS.augment_data)
    X_test, y_test = load_data(test_dir, FLAGS.img_size, FLAGS.augment_data)

    # Sanity check: show the shapes of everything that was loaded.
    print("Train set image size : ", X_train.shape)
    print("Train set label size : ", y_train.shape)
    print("Test set image size : ", X_test.shape)
    print("Test set label size : ", y_test.shape)
    print("Dataset loaded successfully.")

    # Build the network and train it; training returns the loss histories
    # for the train and test sets.
    network = UNet(FLAGS.img_size)
    loss_histories = network.train(
        X_train,
        y_train,
        X_test,
        y_test,
        FLAGS.num_epochs,
        FLAGS.learning_rate,
        FLAGS.model_save_dir,
    )
    train_loss_values, test_loss_values = loss_histories

    # Persist the curves as an image.
    plot_loss(train_loss_values, test_loss_values, 'Loss.png')
shuffle=True, num_workers=8, pin_memory=False, drop_last=True) val_dataloader = DataLoader(val_dataset, batch_size=batchSize, shuffle=False, num_workers=8, pin_memory=False) trainSize = len(train_dataset) valSize = len(val_dataset) for epoch in range(epoch0, epoch_num): train_loss, val_loss = 0.0, 0.0 G.train(), D.train() for i, batch in enumerate(train_dataloader): for d_iter in range(1): optimizer_D.zero_grad() img, label, onehot = batch[0].to(device), batch[1].to( device), batch[2].to(device) output = G(img) real_imgs = torch.cat((img, onehot), 1) fake_imgs = torch.cat((img, output.detach()), 1) loss_D = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs)) loss_D.backward() optimizer_D.step()
def main():
    """Train a UNet with Dice loss and keep the weights with the best validation DSC.

    Every epoch runs a training phase followed by a validation phase. Training
    losses are logged whenever ``(step + 1) % 10 == 0``; after each validation
    phase the mean per-volume Dice score is logged and, if it improves on the
    best score so far, the model weights are written to ``./weights/unet.pt``.
    """
    weights = './weights'
    logs = './logs'
    makedirs(weights, logs)
    #snapshot(logs)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Data configuration.
    images = './kaggle_3m'
    image_size = 256
    scale = 0.05
    angle = 15
    batch_size = 16
    workers = 4
    loader_train, loader_valid = data_loaders(images, image_size, scale, angle,
                                              batch_size, workers)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels,
                out_channels=Dataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    lr = 0.0001
    optimizer = optim.Adam(unet.parameters(), lr)

    logger = Logger(logs)
    loss_train = []
    loss_valid = []

    step = 0
    epochs = 100
    vis_images = 200
    vis_freq = 10

    for epoch in tqdm(range(epochs), total=epochs):
        # Prediction images are logged every vis_freq-th epoch and on the last one.
        visualise = (epoch % vis_freq == 0) or (epoch == epochs - 1)

        for phase in ["train", "valid"]:
            training = phase == "train"
            if training:
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if training:
                    step += 1

                x, y_true = data
                x = x.to(device)
                y_true = y_true.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(training):
                    y_pred = unet(x)
                    loss = dsc_loss(y_pred, y_true)

                    if training:
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()
                    else:
                        loss_valid.append(loss.item())
                        # Collect per-sample arrays for the per-volume Dice score.
                        validation_pred.extend(y_pred.detach().cpu().numpy())
                        validation_true.extend(y_true.detach().cpu().numpy())
                        if visualise and i * batch_size < vis_images:
                            # Cap the logged images so at most vis_images are shown.
                            num_images = vis_images - i * batch_size
                            logger.image_list_summary(
                                "image/{}".format(i),
                                log_images(x, y_true, y_pred)[:num_images],
                                step,
                            )

                if training and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if not training:
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                per_volume = dsc_per_volume(
                    validation_pred,
                    validation_true,
                    loader_valid.dataset.patient_slice_index,
                )
                mean_dsc = np.mean(per_volume)
                logger.scalar_summary("val_dsc", mean_dsc, step)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(),
                               os.path.join(weights, "unet.pt"))
                loss_valid = []

    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
else: loss = criterion(output, target) out = output.cpu().detach().numpy() out = np.array(out[0] > 0.5) dice_val += dice(out, target_pt[0]) test_loss += loss.item() * data.size(0) dice_val = dice_val / numpts test_loss = test_loss / numpts print("Dice: " + str(dice_val)) print("Loss: " + str(test_loss)) # Training model.train() np.random.shuffle(indices) all_data = all_data[indices] labels = labels[indices] train_loss = 0.0 numpts = len(all_data) for i in range(0, numpts, batch_size): data = all_data[i:i+batch_size] target = labels[i:i+batch_size] if LOSS_NUM == 3 or LOSS_NUM == 4: wts = weights[i:i+batch_size] # Doesn't improve performance if AUG > 0:
class Instructor:
    ''' Model training and evaluation '''

    def __init__(self, opt):
        """Build the datasets and the UNet described by ``opt``.

        In inference mode only the test set is created; otherwise the train
        and validation sets are. An optional checkpoint is loaded from
        ``./state_dict/`` before the model is (optionally) wrapped in
        ``DataParallel`` and moved to ``opt.device``.
        """
        self.opt = opt
        if opt.inference:
            self.testset = TestImageDataset(fdir=opt.impaths['test'],
                                            imsize=opt.imsize)
        else:
            self.trainset = ImageDataset(fdir=opt.impaths['train'],
                                         bdir=opt.impaths['btrain'],
                                         imsize=opt.imsize,
                                         mode='train',
                                         aug_prob=opt.aug_prob,
                                         prefetch=opt.prefetch)
            self.valset = ImageDataset(fdir=opt.impaths['val'],
                                       bdir=opt.impaths['bval'],
                                       imsize=opt.imsize,
                                       mode='val',
                                       aug_prob=opt.aug_prob,
                                       prefetch=opt.prefetch)
        self.model = UNet(n_channels=3,
                          n_classes=1,
                          bilinear=self.opt.use_bilinear)
        if opt.checkpoint:
            self.model.load_state_dict(
                torch.load('./state_dict/{:s}'.format(opt.checkpoint),
                           map_location=self.opt.device))
            print('checkpoint {:s} has been loaded'.format(opt.checkpoint))
        if opt.multi_gpu == 'on':
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.to(opt.device)
        self._print_args()

    def _print_args(self):
        """Print parameter counts and all options; keep the text in ``self.info``."""
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            # numel() yields a plain int, so the counts print as numbers
            # rather than as tensors.
            n_params = p.numel()
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        self.info = 'n_trainable_params: {0}, n_nontrainable_params: {1}\n'.format(
            n_trainable_params, n_nontrainable_params)
        self.info += 'training arguments:\n' + '\n'.join([
            '>>> {0}: {1}'.format(arg, getattr(self.opt, arg))
            for arg in vars(self.opt)
        ])
        if self.opt.device.type == 'cuda':
            # The options live on self.opt; a bare `opt` is not defined here.
            print('cuda memory allocated:',
                  torch.cuda.memory_allocated(self.opt.device.index))
        print(self.info)

    def _reset_records(self):
        """Start a fresh training history."""
        self.records = {
            'best_epoch': 0,
            'best_dice': 0,
            'train_loss': list(),
            'val_loss': list(),
            'val_dice': list(),
            'checkpoints': list()
        }

    def _update_records(self, epoch, train_loss, val_loss, val_dice):
        """Append one epoch to the history; checkpoint when ``val_dice`` improves."""
        if val_dice > self.records['best_dice']:
            path = './state_dict/{:s}_dice{:.4f}_temp{:s}.pt'.format(
                self.opt.model_name, val_dice,
                str(time.time())[-6:])
            if self.opt.multi_gpu == 'on':
                # Save the wrapped module so the keys carry no "module." prefix.
                torch.save(self.model.module.state_dict(), path)
            else:
                torch.save(self.model.state_dict(), path)
            self.records['best_epoch'] = epoch
            self.records['best_dice'] = val_dice
            self.records['checkpoints'].append(path)
        self.records['train_loss'].append(train_loss)
        self.records['val_loss'].append(val_loss)
        self.records['val_dice'].append(val_dice)

    def _draw_records(self):
        """Finalise the best checkpoint, plot the curves and write a text report."""
        if not self.records['train_loss']:
            # Nothing was recorded (e.g. num_epoch == 0); min() below would fail.
            print('no epochs were recorded; nothing to report')
            return
        timestamp = str(int(time.time()))
        print('best epoch: {:d}'.format(self.records['best_epoch']))
        print('best train loss: {:.4f}, best val loss: {:.4f}'.format(
            min(self.records['train_loss']), min(self.records['val_loss'])))
        print('best val dice {:.4f}'.format(self.records['best_dice']))
        if self.records['checkpoints']:
            # The last temporary checkpoint is the best one: give it its final
            # name and delete the superseded ones. The list is empty when the
            # validation dice never rose above its initial value.
            os.rename(
                self.records['checkpoints'][-1],
                './state_dict/{:s}_dice{:.4f}_save{:s}.pt'.format(
                    self.opt.model_name, self.records['best_dice'], timestamp))
            for path in self.records['checkpoints'][0:-1]:
                os.remove(path)
        # Draw figures
        plt.figure()
        trainloss, = plt.plot(self.records['train_loss'])
        valloss, = plt.plot(self.records['val_loss'])
        plt.legend([trainloss, valloss], ['train', 'val'], loc='upper right')
        plt.title('{:s} loss curve'.format(timestamp))
        plt.savefig('./figs/{:s}_loss.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        plt.figure()
        valdice, = plt.plot(self.records['val_dice'])
        plt.title('{:s} dice curve'.format(timestamp))
        plt.savefig('./figs/{:s}_dice.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        # Save report
        report = '\t'.join(
            ['val_dice', 'train_loss', 'val_loss', 'best_epoch', 'timestamp'])
        report += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:d}\t{:s}\n{:s}".format(
            self.records['best_dice'], min(self.records['train_loss']),
            min(self.records['val_loss']), self.records['best_epoch'],
            timestamp, self.info)
        with open('./logs/{:s}_log.txt'.format(timestamp), 'w') as f:
            f.write(report)
        print('report saved:', './logs/{:s}_log.txt'.format(timestamp))

    def _train(self, train_dataloader, criterion, optimizer):
        """Run one training epoch and return the sample-weighted mean loss."""
        self.model.train()
        train_loss, n_total, n_batch = 0, 0, len(train_dataloader)
        for i_batch, sample_batched in enumerate(train_dataloader):
            inputs, target = sample_batched[0].to(
                self.opt.device), sample_batched[1].to(self.opt.device)
            predict = self.model(inputs)
            optimizer.zero_grad()
            loss = criterion(predict, target)
            loss.backward()
            optimizer.step()
            # Weight by the number of samples in the batch. len(sample_batched)
            # is the length of the (inputs, target) container, not the batch size.
            batch_size = inputs.size(0)
            train_loss += loss.item() * batch_size
            n_total += batch_size
            ratio = int((i_batch + 1) * 50 / n_batch)
            sys.stdout.write("\r[" + ">" * ratio + " " * (50 - ratio) +
                             "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                                      (i_batch + 1) * 100 /
                                                      n_batch))
            sys.stdout.flush()
        print()
        return train_loss / n_total

    def _evaluation(self, val_dataloader, criterion):
        """Return the sample-weighted mean loss and dice over ``val_dataloader``."""
        self.model.eval()
        val_loss, val_dice, n_total = 0, 0, 0
        with torch.no_grad():
            for sample_batched in val_dataloader:
                inputs, target = sample_batched[0].to(
                    self.opt.device), sample_batched[1].to(self.opt.device)
                predict = self.model(inputs)
                loss = criterion(predict, target)
                dice = dice_coeff(predict, target)
                # Same weighting rule as in _train: use the real batch size.
                batch_size = inputs.size(0)
                val_loss += loss.item() * batch_size
                val_dice += dice.item() * batch_size
                n_total += batch_size
        return val_loss / n_total, val_dice / n_total

    def run(self):
        """Train for ``opt.num_epoch`` epochs, validating after each, then report."""
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.Adam(_params,
                                     lr=self.opt.lr,
                                     weight_decay=self.opt.l2reg)
        criterion = BCELoss2d()
        train_dataloader = DataLoader(dataset=self.trainset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=True)
        val_dataloader = DataLoader(dataset=self.valset,
                                    batch_size=self.opt.batch_size,
                                    shuffle=False)
        self._reset_records()
        for epoch in range(self.opt.num_epoch):
            train_loss = self._train(train_dataloader, criterion, optimizer)
            val_loss, val_dice = self._evaluation(val_dataloader, criterion)
            self._update_records(epoch, train_loss, val_loss, val_dice)
            print(
                '{:d}/{:d} > train loss: {:.4f}, val loss: {:.4f}, val dice: {:.4f}'
                .format(epoch + 1, self.opt.num_epoch, train_loss, val_loss,
                        val_dice))
        self._draw_records()

    def inference(self):
        """Predict every test image and let the test set save each result."""
        test_dataloader = DataLoader(dataset=self.testset,
                                     batch_size=1,
                                     shuffle=False)
        n_batch = len(test_dataloader)
        # A freshly built module is in training mode; switch to eval so layers
        # such as batch norm or dropout (if the model has any) behave
        # deterministically at batch size 1.
        self.model.eval()
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(test_dataloader):
                index, inputs = sample_batched[0], sample_batched[1].to(
                    self.opt.device)
                predict = self.model(inputs)
                self.testset.save_img(index.item(), predict, self.opt.use_crf)
                ratio = int((i_batch + 1) * 50 / n_batch)
                sys.stdout.write(
                    "\r[" + ">" * ratio + " " * (50 - ratio) +
                    "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                             (i_batch + 1) * 100 / n_batch))
                sys.stdout.flush()
        print()
def train(n_channels, n_classes, bilinear, epochs, batch_size, lr, val_rate, num_workers, pin_memory, roots, threshold):
    """Train a UNet on ``SuperviselyDataset`` with MSE loss and save model and losses.

    Args:
        n_channels: Number of input channels passed to ``UNet``.
        n_classes: Number of output channels passed to ``UNet``; also selects
            the mask dtype (float32 for a single class, long otherwise).
        bilinear: Upscaling flag passed to ``UNet``.
        epochs: Number of passes over the training split.
        batch_size: Batch size of both data loaders.
        lr: Learning rate of the RMSprop optimizer.
        val_rate: Fraction of the dataset held out as the validation split.
        num_workers: Worker count of both data loaders.
        pin_memory: ``pin_memory`` flag of both data loaders.
        roots: Triple ``(data_root, model_root, log_root)``.
        threshold: Unused by the active code; it is only referenced by the
            validation/visualisation step that is currently disabled.
    """
    data_root, model_root, log_root = roots

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        # These queries raise on CPU-only hosts, so only run them on CUDA.
        print(torch.cuda.get_device_properties(device))
    print("id:{}".format(device))
    if device.type == 'cuda':
        print("name:{}".format(torch.cuda.get_device_name(0)))
    logging.info(f'Using device {device}')

    model = UNet(n_channels, n_classes, bilinear).to(device)
    logging.info(f'Network:\n'
                 f'\t{n_channels} input channels\n'
                 f'\t{n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if bilinear else "Dilated conv"} upscaling')

    dataset = SuperviselyDataset(data_root)
    num_val = int(len(dataset) * val_rate)
    num_train = len(dataset) - num_val
    train_data, val_data = random_split(dataset, [num_train, num_val])
    train_loader = DataLoader(train_data,
                              batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=pin_memory)
    # NOTE: val_loader and writer are only consumed by the periodic
    # validation / TensorBoard logging step, which is currently disabled.
    val_loader = DataLoader(val_data,
                            batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=pin_memory)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
    logging.info(
        f'''Starting training: Epochs: {epochs} Batch size: {batch_size} '''
        f'''Learning rate (original): {lr} Training item: {num_train} '''
        f'''Validation item: {num_val} Device: {device.type}''')

    # Alternative kept for reference:
    # criterion = nn.CrossEntropyLoss() if n_classes > 1 else nn.BCEWithLogitsLoss()
    criterion = nn.MSELoss()
    optimizer = optim.RMSprop(model.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)

    # Loop-invariant: the mask dtype depends only on the number of classes.
    mask_type = torch.float32 if n_classes == 1 else torch.long

    losses = []
    for epoch in range(epochs):
        model.train()
        with tqdm(total=num_train) as t:
            t.set_description('epoch: {}/{}'.format(epoch + 1, epochs))
            for img, mask in train_loader:
                mask = mask.to(device=device, dtype=mask_type)
                img = img.to(device=device, dtype=torch.float32)

                pred = model(img)
                loss = criterion(pred, mask)

                optimizer.zero_grad()
                loss.backward()
                # Clip after backward(): clipping before it only touches the
                # just-zeroed gradients and has no effect on the update.
                nn.utils.clip_grad_value_(model.parameters(), 0.1)
                optimizer.step()

                loss_value = loss.item()
                losses.append(loss_value)
                t.set_postfix(loss='{:.6f}'.format(loss_value),
                              lr='%.8f' % optimizer.param_groups[0]['lr'])
                t.update(img.shape[0])

    save_model_and_loss(model, model_root, losses, log_root)
    writer.close()
def train(args):
    """Train UNet from datasets.

    Reads the train and validation splits of ``SSDataset`` from
    ``args.dataset_path``, writes the mask colour list to
    ``args.mask_json_path``, trains for ``args.epochs`` epochs with Adam,
    saves a checkpoint every ``args.save_epoch`` epochs under
    ``args.checkpoint_path``, optionally saves a loss plot to
    ``args.summary_image`` and finally stores the weights in
    ``args.model_state_dict``.
    """
    # dataset
    print('Reading dataset from {}...'.format(args.dataset_path))
    train_dataset = SSDataset(dataset_path=args.dataset_path, is_train=True)
    val_dataset = SSDataset(dataset_path=args.dataset_path, is_train=False)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    # mask
    with open(args.mask_json_path, 'w', encoding='utf-8') as mask:
        colors = SSDataset.all_colors
        mask.write(json.dumps(colors))
    print('Mask colors list has been saved in {}'.format(
        args.mask_json_path))

    # model
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()

    # setting
    lr = args.lr  # 1e-3
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = loss_fn

    # run
    train_losses = []
    val_losses = []
    print('Start training...')
    for epoch_idx in range(args.epochs):
        # train
        net.train()
        train_loss = 0.0
        for xs, ys in train_dataloader:
            if args.cuda:
                xs = xs.cuda()
                ys = ys.cuda()
            ys_pred = net(xs)
            loss = criterion(ys_pred, ys)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a plain float: summing the loss tensors would keep
            # every batch's autograd graph alive for the whole epoch.
            train_loss += loss.item()

        # val
        net.eval()
        val_loss = 0.0
        # No gradients are needed for validation.
        with torch.no_grad():
            for xs, ys in val_dataloader:
                if args.cuda:
                    xs = xs.cuda()
                    ys = ys.cuda()
                ys_pred = net(xs)
                val_loss += criterion(ys_pred, ys).item()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print('Epoch: {}, Train total loss: {}, Val total loss: {}'.format(
            epoch_idx + 1, train_loss, val_loss))

        # save
        if (epoch_idx + 1) % args.save_epoch == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_path,
                'checkpoint_{}.pth'.format(epoch_idx + 1))
            torch.save(net.state_dict(), checkpoint_path)
            print('Saved Checkpoint at Epoch {} to {}'.format(
                epoch_idx + 1, checkpoint_path))

    # summary
    if args.do_save_summary:
        epoch_range = list(range(1, args.epochs + 1))
        plt.plot(epoch_range, train_losses, 'r', label='Train loss')
        # Plot the per-epoch history, not the last epoch's scalar.
        plt.plot(epoch_range, val_losses, 'g', label='Val loss')
        plt.legend()
        # savefig writes the current figure; imsave expects an image array.
        plt.savefig(args.summary_image)
        print('Summary images have been saved in {}'.format(
            args.summary_image))

    # save
    net.eval()
    torch.save(net.state_dict(), args.model_state_dict)
    print('Saved state_dict in {}'.format(args.model_state_dict))
def main():
    """Parse the command line and train the burst UNet with periodic validation.

    Training losses are written to TensorBoard every iteration and to
    ``<results>/train.txt`` every 100 batches. Every 1000 batches (including
    the first batch of each epoch) the weights are saved to
    ``<results>/model_<n_iter>.pt`` and the whole validation set is evaluated,
    with the summed loss appended to ``<results>/eval.txt``.
    """

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into True. Explicit false spellings now yield False; any other value
        keeps yielding True as before.
        """
        return value.strip().lower() not in ('false', 'f', 'no', 'n', '0', '')

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', metavar='bs', type=int, default=2)
    parser.add_argument('--path', type=str, default='../../data')
    parser.add_argument('--results', type=str, default='../../results/model')
    parser.add_argument('--nw', type=int, default=0)
    parser.add_argument('--max_images', type=int, default=None)
    parser.add_argument('--val_size', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--lr_decay', type=float, default=0.99997)
    parser.add_argument('--kernel_lvl', type=float, default=1)
    parser.add_argument('--noise_lvl', type=float, default=1)
    parser.add_argument('--motion_blur', type=_str2bool, default=False)
    parser.add_argument('--homo_align', type=_str2bool, default=False)
    parser.add_argument('--resume', type=_str2bool, default=False)
    args = parser.parse_args()

    print()
    print(args)
    print()

    if not os.path.isdir(args.results):
        os.makedirs(args.results)
    PATH = args.results

    if not args.resume:
        with open(PATH + "/param.txt", "a+") as f:
            f.write(str(args))

    writer = SummaryWriter(PATH + '/runs')

    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else "cpu")

    # Parameters
    params = {'batch_size': args.bs, 'shuffle': True, 'num_workers': args.nw}

    # Generators
    print('Initializing training set')
    training_set = Dataset(args.path + '/train/', args.max_images,
                           args.kernel_lvl, args.noise_lvl, args.motion_blur,
                           args.homo_align)
    training_generator = data.DataLoader(training_set, **params)

    print('Initializing validation set')
    validation_set = Dataset(args.path + '/test/', args.val_size,
                             args.kernel_lvl, args.noise_lvl,
                             args.motion_blur, args.homo_align)
    validation_generator = data.DataLoader(validation_set, **params)

    # Model
    model = UNet(in_channel=3, out_channel=3)
    if args.resume:
        models_path = get_newest_model(PATH)
        print('loading model from ', models_path)
        # map_location lets a checkpoint written on a GPU host load on CPU.
        model.load_state_dict(torch.load(models_path, map_location=device))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    model.to(device)

    # Loss + optimizer
    criterion = BurstLoss()
    optimizer = RAdam(model.parameters(), lr=args.lr)
    # 8 // bs is 0 for bs > 8, which StepLR cannot handle; decay at least
    # every step in that case.
    scheduler = StepLR(optimizer,
                       step_size=max(1, 8 // args.bs),
                       gamma=args.lr_decay)

    if args.resume:
        # ndmin=2 keeps the array two-dimensional even when the log holds a
        # single row, so the last iteration count can always be read.
        history = np.loadtxt(PATH + '/train.txt', delimiter=',', ndmin=2)
        n_iter = int(history[-1, 0])
    else:
        n_iter = 0

    # Loop over epochs
    for epoch in range(args.epochs):
        train_loss = 0.0

        # Training
        model.train()
        for i, (X_batch, y_labels) in enumerate(training_generator):
            # Alter the burst length for each mini batch
            burst_length = np.random.randint(2, 9)
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch = X_batch.to(device).type(torch.float)
            y_labels = y_labels.to(device).type(torch.float)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            pred = model(X_batch)
            loss = criterion(pred, y_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.detach().cpu().numpy()
            writer.add_scalar('training_loss', loss.item(), n_iter)

            if i % 100 == 0 and i > 0:
                loss_printable = str(np.round(train_loss, 2))
                with open(PATH + "/train.txt", "a+") as f:
                    f.write(str(n_iter) + "," + loss_printable + "\n")
                print("training loss ", loss_printable)
                train_loss = 0.0

            if i % 1000 == 0:
                # Save the bare module so the checkpoint loads without
                # DataParallel as well.
                if torch.cuda.device_count() > 1:
                    state = model.module.state_dict()
                else:
                    state = model.state_dict()
                torch.save(
                    state,
                    os.path.join(PATH, 'model_' + str(int(n_iter)) + '.pt'))

                # Validation
                val_loss = 0.0
                with torch.set_grad_enabled(False):
                    model.eval()
                    for v, (X_batch, y_labels) in enumerate(validation_generator):
                        # Alter the burst length for each mini batch
                        burst_length = np.random.randint(2, 9)
                        X_batch = X_batch[:, :burst_length, :, :, :]

                        # Transfer to GPU
                        X_batch = X_batch.to(device).type(torch.float)
                        y_labels = y_labels.to(device).type(torch.float)

                        # forward only
                        pred = model(X_batch)
                        loss = criterion(pred, y_labels)
                        val_loss += loss.detach().cpu().numpy()

                        if v < 5:
                            im = make_im(pred, X_batch, y_labels)
                            writer.add_image('image_' + str(v), im, n_iter)

                writer.add_scalar('validation_loss', val_loss, n_iter)
                loss_printable = str(np.round(val_loss, 2))
                print('validation loss ', loss_printable)
                with open(PATH + "/eval.txt", "a+") as f:
                    f.write(str(n_iter) + "," + loss_printable + "\n")

                # Validation left the model in eval mode; the remaining
                # batches of this epoch must train in training mode again.
                model.train()

            n_iter += args.bs
def train(args):
    """Train a UNet for lung segmentation and save every new best model.

    The rows of ``dataset.csv`` are split into the first 600 for training and
    the remainder for validation. Each epoch runs a training phase and a
    validation phase; whenever the mean validation dice score exceeds the best
    one seen so far, the weights are saved under ``args.weights`` as
    ``unet_<score>.pt``.

    Args:
        args: Options object; ``args.epochs`` and ``args.weights`` are read
            here and the whole object is forwarded to
            ``LungSegmentationDataGen``.
    """
    # Read the whole index file and close it right away.
    with open("dataset.csv", "r") as dataset_file:
        dataset = dataset_file.readlines()
    train_set = dataset[:600]
    val_set = dataset[600:]

    root_dir = "data/Lung_Segmentation/"
    train_data = LungSegmentationDataGen(train_set, root_dir, args)
    val_data = LungSegmentationDataGen(val_set, root_dir, args)

    train_dataloader = DataLoader(train_data,
                                  batch_size=5,
                                  shuffle=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_data,
                                batch_size=5,
                                shuffle=True,
                                num_workers=4)
    dataloaders = {"train": train_dataloader, "val": val_dataloader}
    dataset_sizes = {"train": len(train_set), "val": len(val_set)}
    print("dataset_sizes: {}".format(dataset_sizes))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet(in_channels=1)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())

    loss_train = []
    loss_valid = []
    current_mean_dsc = 0.0
    best_validation_dsc = 0.0
    epochs = args.epochs

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
        dice_score_list = []

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # Iterate over data.
            for i, (inputs, y_true) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                y_true = y_true.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # Gradients are tracked only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    # forward pass with batch input
                    y_pred = model(inputs)
                    loss = dice_loss(y_true, y_pred)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss_train.append(loss.item())
                        # calculate the gradients based on loss
                        loss.backward()
                        # update the weights
                        optimizer.step()

                    if phase == "val":
                        loss_valid.append(loss.item())
                        dsc = dice_score(y_true, y_pred)
                        print("step: {}, val_loss: {}, val dice_score: {}".
                              format(i, loss, dsc))
                        # Move to host memory first: .numpy() raises for
                        # tensors that live on a CUDA device.
                        dice_score_list.append(dsc.detach().cpu().numpy())

                if phase == "train" and (i + 1) % 10 == 0:
                    print("step:{}, train_loss: {}".format(
                        i + 1, np.mean(loss_train)))
                    loss_train = []

            if phase == "val":
                print("mean val_loss: {}".format(np.mean(loss_valid)))
                loss_valid = []
                current_mean_dsc = np.mean(dice_score_list)
                print("validation set dice_score: {}".format(current_mean_dsc))
                if current_mean_dsc > best_validation_dsc:
                    best_validation_dsc = current_mean_dsc
                    print("best dice_score on val set: {}".format(
                        best_validation_dsc))
                    model_name = "unet_{0:.2f}.pt".format(best_validation_dsc)
                    torch.save(model.state_dict(),
                               os.path.join(args.weights, model_name))
def main(_):
    """Build the dataset and the selected model, then run the requested phase.

    Reads the module-level ``args`` and ``label_type``. ``label_type`` selects
    the dataset class (``'one_hot'`` or ``'semantic'``), ``args.model`` selects
    the network class and ``args.phase`` selects ``train``, ``train_repeat``
    (phase ``'repeat'``) or ``test``. Any other phase value does nothing.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``label_type`` or ``args.model`` is not a known choice.
    """
    # Make checkpoint directory
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    # Create a dataset object. Both dataset classes take the same arguments,
    # so only the class is selected here.
    if label_type == 'one_hot':
        data_cls = utils.DataOneHot
    elif label_type == 'semantic':
        data_cls = utils.DataSemantic
    else:
        raise ValueError("Unknown label_type: {!r}".format(label_type))
    data = data_cls(debug=args.debug,
                    patch_overlap=args.patch_overlap,
                    im_size=args.im_size,
                    band_n=args.band_n,
                    t_len=args.t_len,
                    path=args.path,
                    class_n=args.class_n,
                    pc_mode=args.pc_mode,
                    test_n_limit=args.test_n_limit,
                    memory_mode=args.memory_mode,
                    balance_samples_per_class=args.balance_samples_per_class,
                    test_get_stride=args.test_get_stride,
                    n_apriori=args.n_apriori,
                    patch_length=args.patch_len,
                    squeeze_classes=args.squeeze_classes,
                    im_h=args.im_h,
                    im_w=args.im_w,
                    id_first=args.id_first,
                    train_test_mask_name=args.train_test_mask_name,
                    test_overlap_full=args.test_overlap_full,
                    ram_store=args.ram_store,
                    patches_save=args.patches_save)

    # Load images and create dataset (Extract patches)
    if args.memory_mode == "ram":
        data.create()
        deb.prints(data.ram_data["train"]["ims"].shape)

    # Model name -> class. The lambdas defer the name lookup to the selected
    # entry, exactly like the branch-per-model chain this replaces, so a class
    # that is not available in this module only matters when it is requested.
    model_classes = {
        'convlstm': lambda: conv_lstm,
        'conv3d': lambda: Conv3DMultitemp,
        'unet': lambda: UNet,
        'smcnn': lambda: SMCNN,
        'smcnnlstm': lambda: SMCNNlstm,
        'smcnn_unet': lambda: SMCNN_UNet,
        'smcnn_conv3d': lambda: SMCNN_conv3d,
        'lstm': lambda: lstm,
        'convlstm_semantic': lambda: conv_lstm_semantic,
        'smcnn_semantic': lambda: SMCNN_semantic,
    }
    if args.model not in model_classes:
        raise ValueError("Unknown model: {!r}".format(args.model))
    model_cls = model_classes[args.model]()

    # Run tensorflow session
    with tf.Session() as sess:
        # Create a neural network object (Define model graph)
        model = model_cls(sess,
                          batch_size=args.batch_size,
                          epoch=args.epoch,
                          train_size=args.train_size,
                          timesteps=args.timesteps,
                          patch_len=args.patch_len,
                          kernel=args.kernel,
                          channels=args.channels,
                          filters=args.filters,
                          n_classes=args.n_classes,
                          checkpoint_dir=args.checkpoint_dir,
                          log_dir=args.log_dir,
                          data=data.ram_data,
                          conf=data.conf,
                          debug=args.debug)

        if args.phase == 'train':
            # Train only once
            model.train(args)
        elif args.phase == 'repeat':
            # Train for a specific number of repetitions
            model.train_repeat(args)
        elif args.phase == 'test':
            # Test best model from experiment repetitions
            model.test(args)
class FenceSegmentationNet():
    """Train and evaluate a U-Net that segments fences in RGB images.

    A data directory must contain an ``images/`` and a ``labels/``
    sub-directory holding equally named files.

    Args:
        filePathTrain: Directory with the training images and labels.
        filePathTest: Optional directory with test images and labels. When
            omitted no test loader is built and ``testModel`` cannot be used.
    """

    def __init__(self, filePathTrain, filePathTest=None):
        # Hyperparameters
        self.batchSize = 1
        self.numEpochs = 10
        self.learningRate = 0.001
        self.validPercent = 0.1
        self.trainShuffle = True
        self.testShuffle = False
        self.momentum = 0.99
        self.imageDim = 128

        # Variables
        self.imageDirectory = filePathTrain
        self.labelDirectory = filePathTrain
        self.filePathTest = filePathTest
        self.numChannels = 3
        self.numClasses = 1

        # Device configuration
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Load dataset
        self.trainLoader = self.getTrainingLoader()
        self.testLoader = (self.getTestLoader()
                           if filePathTest is not None else None)

        # Setup model
        self.model = UNet(n_channels=self.numChannels,
                          n_classes=self.numClasses,
                          bilinear=True).to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.learningRate,
                                         momentum=self.momentum)
        self.criterion = DiceLoss()

    @staticmethod
    def _toChannelsFirst(image, imageDim):
        """Return an ``(H, W, 3)`` image as a ``(3, H, W)`` array.

        The axes are permuted with ``transpose``; a plain ``reshape`` would
        keep the memory order and interleave the colour channels.

        Raises:
            ValueError: If the image is not ``imageDim x imageDim x 3``.
        """
        array = np.asarray(image)
        if array.shape != (imageDim, imageDim, 3):
            raise ValueError(
                'expected an image of shape (%d, %d, 3), got %s'
                % (imageDim, imageDim, array.shape))
        return array.transpose(2, 0, 1)

    def loadImages(self, directory):
        """Load all image/label pairs below ``directory``.

        Returns:
            Two lists of arrays: images as ``(3, H, W)`` and binary masks as
            ``(1, H, W)``.
        """
        imageDir = os.path.join(directory, 'images')
        labelDir = os.path.join(directory, 'labels')
        images = []
        labels = []
        # Sorted so that the sample order does not depend on the file system
        for fileName in tqdm(sorted(os.listdir(imageDir)),
                             desc='Loading data'):
            # Context managers release the file handles right after decoding
            with Image.open(os.path.join(imageDir, fileName)) as image:
                imageArray = self._toChannelsFirst(image, self.imageDim)
            with Image.open(os.path.join(labelDir, fileName)) as label:
                mask = label.convert('L').point(
                    lambda x: 0 if x < 128 else 255, '1')
                labelArray = np.asarray(mask).reshape(
                    1, self.imageDim, self.imageDim)
            labels.append(labelArray)
            images.append(imageArray)
        return images, labels

    def _buildLoader(self, directory, shuffle):
        """Wrap the samples found in ``directory`` in a ``DataLoader``."""
        data, labels = self.loadImages(directory)
        # One bulk conversion instead of torch.Tensor(list-of-arrays)
        tensorData = torch.as_tensor(np.array(data), dtype=torch.float32)
        tensorLabels = torch.as_tensor(np.array(labels), dtype=torch.float32)
        tensorDataset = TensorDataset(tensorData, tensorLabels)
        # Batches have the format [Batch, Channels, Height, Width]
        return DataLoader(tensorDataset, batch_size=self.batchSize,
                          shuffle=shuffle)

    def getTrainingLoader(self):
        """Build the loader over the training directory."""
        print('Get trainings loader:')
        return self._buildLoader(self.imageDirectory, self.trainShuffle)

    def getTestLoader(self):
        """Build the loader over the test directory.

        Raises:
            RuntimeError: If no test directory was given to the constructor.
        """
        print('Get test loader:')
        if getattr(self, 'filePathTest', None) is None:
            raise RuntimeError(
                'No test directory configured; pass filePathTest to '
                'FenceSegmentationNet().')
        return self._buildLoader(self.filePathTest, self.testShuffle)

    def trainModel(self, patience=3):
        """Train for ``self.numEpochs`` epochs.

        Args:
            patience: Accepted for interface compatibility; not used.

        Returns:
            List with the mean training loss of every epoch.
        """
        print('Train Model:')
        self.model.train()
        losses = []
        n_total_steps = len(self.trainLoader)
        for epoch in range(self.numEpochs):
            epoch_loss = 0
            tmpStr = 'Epoch [{:>3}/{:>3}]'.format(epoch + 1, self.numEpochs)
            for (samples, labels) in tqdm(self.trainLoader, desc=tmpStr):
                samples = samples.to(self.device, dtype=torch.float32)
                # NOTE(review): testModel feeds float labels while training
                # feeds long labels - confirm which one DiceLoss expects.
                labels = labels.to(self.device, dtype=torch.long)

                # Forward pass
                outputs = self.model(samples)
                loss = self.criterion(outputs, labels)
                epoch_loss += loss.item()

                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # max() keeps an empty loader from dividing by zero
            tot_loss = epoch_loss / max(n_total_steps, 1)
            print(tot_loss)
            losses.append(tot_loss)
        return losses

    def testModel(self):
        """Return the summed batch dice score per sample, as a percentage.

        Raises:
            RuntimeError: If no test loader is available.
        """
        print('Testing Model:')
        testLoader = getattr(self, 'testLoader', None)
        if testLoader is None:
            raise RuntimeError(
                'No test loader available; pass filePathTest to '
                'FenceSegmentationNet().')
        self.model.eval()
        n_correct = 0
        n_samples = 0
        tmpStr = 'Testing'
        # Inference only: no autograd graph is needed
        with torch.no_grad():
            for samples, labels in tqdm(testLoader, desc=tmpStr):
                samples = samples.to(self.device, dtype=torch.float32)
                labels = labels.to(self.device, dtype=torch.float32)
                outputs = self.model(samples)
                n_samples += labels.size(0)
                pred = torch.sigmoid(outputs)
                pred = (pred > 0.2).float()
                n_correct += dice_coeff(pred, labels).item()
        if n_samples == 0:
            return 0.0
        acc = 100.0 * n_correct / n_samples
        return acc

    def plot_history(self, train_losses):
        """Plot the per-epoch training loss and save it as ``plot.png``."""
        print('Plotting epoch history')
        plt.figure(figsize=(3, 3))
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.plot(train_losses, label='train')
        plt.ylim(0, 1)
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.savefig('plot.png')

    def saveModel(self, fileName):
        """Write the model weights to ``fileName``."""
        torch.save(self.model.state_dict(), fileName)

    def loadModel(self, fileName):
        """Load model weights from ``fileName`` onto the current device."""
        # map_location lets a GPU-trained checkpoint load on a CPU-only host
        self.model.load_state_dict(
            torch.load(fileName, map_location=self.device))
# Build the context inputs: every interior slice is stacked together with
# its predecessor and successor, giving sequences of 3 slices.
print('Get contexts inputs ...')
ct_seqs = [
    np.stack((imgs[pos - 1], imgs[pos], imgs[pos + 1]))
    for pos in range(1, len(imgs) - 1)
]
ct_seqs_np = np.array(ct_seqs)

# The first and last segmentation masks have no complete context window,
# so they are dropped along axis 0.
cut_png_imgs = np.delete(png_imgs, [0, len(png_imgs) - 1], 0)

# Create the u-net and train it on the CT context sequences
fcn = UNet()
fcn.train(ct_seqs_np, cut_png_imgs)  ##### TRAIN CT
# fcn.train(imgs_mri, png_imgs, (256,256,1)) ##### TRAIN MRI

# # TESTING
# test_data_path = ['./CT_data_batch1/10/DICOM_anon/i0081,0000b.dcm']
# test2 = read_dicoms(test_data_path)
# test1 = get_pixels_hu(test2)
# hi = fcn.predict(test1)
# # OUTPUT FILE plot
# plt.imshow(np.squeeze(hi[0], axis=2))
# plt.show()
print('OK')
def train():
    """Train the model, evaluating on the dev set after every epoch.

    Reads its whole configuration from the module-level ``args``. With
    ``args.eval`` set, a saved model is evaluated once and the process exits.
    With ``args.load_model`` set, training resumes from the saved weights and
    score histories. The weights with the best dev F1 are written to
    ``args.model_dir``; the F1/EM histories are pickled next to it.
    """
    os.makedirs('train_model/', exist_ok=True)
    os.makedirs('result/', exist_ok=True)

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)
    if args.use_cuda:
        model = model.cuda()
    # Lets a checkpoint written on a GPU be loaded on a CPU-only host
    map_location = None if args.use_cuda else 'cpu'

    def evaluate(batches):
        # Single place for the dev-set evaluation arguments
        return model.Evaluate(
            batches,
            args.data_path + 'dev_eval.json',
            answer_file='result/' + args.model_dir.split('/')[-1] + '.answers',
            drop_file=args.data_path + 'drop.json',
            dev=args.data_path + 'dev-v2.0.json')

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    if args.eval:
        print('load model...')
        model.load_state_dict(
            torch.load(args.model_dir, map_location=map_location))
        model.eval()
        evaluate(dev_batches)
        sys.exit()

    if args.load_model:
        print('load model...')
        model.load_state_dict(
            torch.load(args.model_dir, map_location=map_location))
        model.eval()
        _, F1 = evaluate(dev_batches)
        best_score = F1
        # NOTE: pickle is only safe here because these files are written by
        # this function itself; never point model_dir at untrusted data.
        with open(args.model_dir + '_f1_scores.pkl', 'rb') as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + '_em_scores.pkl', 'rb') as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # A list, not a filter() iterator: the optimizer constructor consumes
    # its argument, and an exhausted iterator would make the gradient
    # clipping below a silent no-op.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)
    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts['grad_clipping'])
            optimizer.step()
            model.reset_parameters()

            if i % 100 == 0:
                print(
                    'Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f'
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = evaluate(dev_batches)
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + '_f1_scores.pkl', 'wb') as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + '_em_scores.pkl', 'wb') as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)

        # Step decay of the learning rate every decay_period epochs
        if epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
from model import UNet
from config import UnetConfig

# Build the U-Net from a default-constructed configuration object
config = UnetConfig()
unet = UNet(config=config)
# The two arguments are passed positionally; presumably a dataset directory
# and a log directory - confirm against UNet.train's signature.
unet.train("dataset-postdam", "logs")
class Trainer():
    """Train, validate and test a U-Net with early stopping on validation IOU.

    Args:
        config: Object providing ``save_model_dir``, ``bestModel``,
            ``use_gpu``, ``resume``, ``epochs``, ``train_paitence`` and
            ``tensorboard_path``.
        trainLoader: Loader yielding ``(images, targets)`` training batches.
        validLoader: Loader yielding ``(images, targets)`` validation batches.
    """

    def __init__(self, config, trainLoader, validLoader):
        self.config = config
        self.trainLoader = trainLoader
        self.validLoader = validLoader

        self.numTrain = len(self.trainLoader.dataset)
        self.numValid = len(self.validLoader.dataset)

        self.saveModelDir = str(self.config.save_model_dir) + "/"
        self.bestModel = config.bestModel
        self.useGpu = self.config.use_gpu

        self.net = UNet()
        if self.config.resume == True:
            print("LOADING SAVED MODEL")
            self.loadCheckpoint()
        else:
            print("INTIALIZING NEW MODEL")

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.net = self.net.to(self.device)
        self.totalEpochs = config.epochs

        self.optimizer = optim.Adam(self.net.parameters(), lr=5e-4)
        self.loss = DiceLoss()
        self.num_params = sum(
            [p.data.nelement() for p in self.net.parameters()])
        self.trainPaitence = config.train_paitence
        # Epochs since the last validation improvement (early stopping)
        self.counter = 0

        if not self.config.resume:
            # self.freezeLayers(6)
            summary(self.net, input_size=(3, 256, 256))
            print('[*] Number of model parameters: {:,}'.format(
                self.num_params))

        self.writer = SummaryWriter(self.config.tensorboard_path + "/")

    def train(self):
        """Run the training loop; stop early when validation IOU stalls."""
        bestIOU = 0
        self.counter = 0

        print("\n[*] Train on {} sample pairs, validate on {} trials".format(
            self.numTrain, self.numValid))

        for epoch in range(0, self.totalEpochs):
            print('\nEpoch: {}/{}'.format(epoch + 1, self.totalEpochs))

            self.trainOneEpoch(epoch)
            validationIOU = self.validationTest(epoch)
            print("VALIDATION IOU: ", validationIOU)

            # check for improvement
            if validationIOU > bestIOU:
                print("COUNT RESET !!!")
                bestIOU = validationIOU
                self.counter = 0
                self.saveCheckPoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.net.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        'best_valid_acc': bestIOU,
                    }, True)
            else:
                self.counter += 1
                if self.counter > self.trainPaitence:
                    self.saveCheckPoint(
                        {
                            'epoch': epoch + 1,
                            'model_state': self.net.state_dict(),
                            'optim_state': self.optimizer.state_dict(),
                            'best_valid_acc': validationIOU,
                        }, False)
                    print("[!] No improvement in a while, stopping training...")
                    print("BEST VALIDATION IOU: ", bestIOU)
                    return None

    def trainOneEpoch(self, epoch):
        """Train for one epoch and log the mean loss/IOU to TensorBoard."""
        self.net.train()
        train_loss = 0
        total_IOU = 0
        num_batches = 0

        for batch_idx, (images, targets) in enumerate(self.trainLoader):
            images = images.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputMaps = self.net(images)
            loss = self.loss(outputMaps, targets)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            current_IOU = calc_IOU(outputMaps, targets)
            total_IOU += current_IOU
            num_batches = batch_idx + 1

            progress_bar(batch_idx, len(self.trainLoader),
                         'Loss: %.3f | IOU: %.3f'
                         % (train_loss / num_batches, current_IOU))

        # Mean over the batches seen. The previous `x/batch_idx+1` divided
        # by one batch too few and then added 1 (and divided by zero for a
        # single-batch epoch).
        if num_batches:
            self.writer.add_scalar('Train/Loss', train_loss / num_batches,
                                   epoch)
            self.writer.add_scalar('Train/IOU', total_IOU / num_batches,
                                   epoch)

    def validationTest(self, epoch):
        """Evaluate on the validation loader and return the mean IOU."""
        self.net.eval()
        validationLoss = []
        total_IOU = []

        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(self.validLoader):
                images = images.to(self.device)
                targets = targets.to(self.device)

                outputMaps = self.net(images)
                loss = self.loss(outputMaps, targets)

                validationLoss.append(loss.item())
                total_IOU.append(calc_IOU(outputMaps, targets))

        meanIOU = np.mean(total_IOU)
        meanValidationLoss = np.mean(validationLoss)
        self.writer.add_scalar('Validation/Loss', meanValidationLoss, epoch)
        self.writer.add_scalar('Validation/IOU', meanIOU, epoch)

        print("VALIDATION LOSS: ", meanValidationLoss)
        return meanIOU

    def test(self, dataLoader):
        """Run the model on the first batch of ``dataLoader``.

        Returns:
            ``(inputs, outputs)``: lists of numpy arrays for that batch.
        """
        self.net.eval()
        testLoss = []
        total_IOU = []
        total_outputs_maps = []
        total_input_images = []

        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(dataLoader):
                images = images.to(self.device)
                targets = targets.to(self.device)

                outputMaps = self.net(images)
                loss = self.loss(outputMaps, targets)

                testLoss.append(loss.item())
                total_IOU.append(calc_IOU(outputMaps, targets))

                total_outputs_maps.append(outputMaps.cpu().detach().numpy())
                # total_input_images.append(transforms.ToPILImage()(images))
                total_input_images.append(images.cpu().detach().numpy())

                # Only the first batch is evaluated (kept from the original)
                break

        meanIOU = np.mean(total_IOU)
        meanLoss = np.mean(testLoss)
        print("TEST IOU: ", meanIOU)
        print("TEST LOSS: ", meanLoss)

        return total_input_images, total_outputs_maps

    def saveCheckPoint(self, state, isBest):
        """Save ``state`` as ``model.pth``; also copy it to ``best_model.pth``
        when ``isBest`` is true."""
        filename = "model.pth"
        ckpt_path = os.path.join(self.saveModelDir, filename)
        torch.save(state, ckpt_path)

        if isBest:
            filename = "best_model.pth"
            shutil.copyfile(ckpt_path,
                            os.path.join(self.saveModelDir, filename))

    def loadCheckpoint(self):
        """Restore ``self.net`` from the saved (best or latest) checkpoint."""
        print("[*] Loading model from {}".format(self.saveModelDir))
        if self.bestModel:
            print("LOADING BEST MODEL")
            filename = "best_model.pth"
        else:
            filename = "model.pth"

        ckpt_path = os.path.join(self.saveModelDir, filename)
        print(ckpt_path)

        if self.useGpu == False:
            # Remap GPU tensors so the checkpoint loads on a CPU-only host
            self.ckpt = torch.load(ckpt_path, map_location='cpu')
        else:
            print("*" * 40 + " LOADING MODEL FROM GPU " + "*" * 40)
            self.ckpt = torch.load(ckpt_path)

        if isinstance(self.ckpt, dict) and 'model_state' in self.ckpt:
            # Format written by saveCheckPoint: a dict of state dicts
            self.net.load_state_dict(self.ckpt['model_state'])
        else:
            # Older files that pickled the whole network object
            self.net = self.ckpt

        if self.useGpu != False:
            self.net.cuda()
# Returns the number of GPUs available. torch.cuda.device_count() # 1 # Gets the name of a device. print(torch.cuda.get_device_name(0)) torch.cuda.device(0) device = 0 # unet = UNet().to(device) unet = UNet(n_class=1).to(device=cuda) criterion = nn.BCEWithLogitsLoss() learning_rates = 1e-5 optimizer = torch.optim.Adam(unet.parameters(), lr=args.learning_rate) unet.train() train_dataloader, val_dataloader = basic_dataloader(args.input_size, args.batch_size, args.num_workers) curr_lr = args.learning_rate print("Initializing Training!") # save_path = 'C:/Users/USER/Desktop/hand/model_save/' # dir_root = 'C:/Users/USER/Desktop/segmentation/' # dir_img = dir_root+ 'data/x/train/' # dir_mask = dir_root+ 'data/y/label/' dir_checkpoint = 'C:/Users/USER/Desktop/segmentation/checkpoints/'
def trainUnet(dirP, name, setLen, epochs=20):
    """Train an image-to-image U-Net with an MSE loss and checkpoint each epoch.

    Args:
        dirP: Directory prefix of the training pairs. Sample ``idx`` is read
            from ``<dirP>\\\\train<idx>x.jpg`` (input) and
            ``<dirP>\\\\train<idx>y.jpg`` (target).
        name: Tag embedded in the checkpoint file names.
        setLen: Number of training pairs (indices ``0 .. setLen - 1``).
        epochs: Number of epochs to train.
    """

    class WPCDEDataset(Dataset):
        """Dataset of numbered ``<idx>x.jpg`` / ``<idx>y.jpg`` RGB pairs."""

        def __init__(self, lenG, root_dir, transform=None):
            self.root_dir = root_dir
            self.lenG = lenG
            self.transform = transform

        def __len__(self):
            return self.lenG

        def __getitem__(self, idx):
            img_nameX = self.root_dir + '%dx.jpg' % (idx)
            img_nameY = self.root_dir + '%dy.jpg' % (idx)
            # convert() returns a loaded copy, so the files can be closed
            with Image.open(img_nameX) as fileX:
                imageX = fileX.convert('RGB')
            with Image.open(img_nameY) as fileY:
                imageY = fileY.convert('RGB')
            if self.transform:
                imageX = self.transform(imageX)
                imageY = self.transform(imageY)
            return imageX, imageY

    transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # NOTE(review): the prefix ends in 'train' with no separator, so files
    # are named 'train0x.jpg' etc. - confirm this matches the data layout.
    trainDatas = WPCDEDataset(setLen, dirP + r'\\' + 'train', transform=transf)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 16
    lr = 2e-4
    weight_decay = 0
    start_epoch = 0
    outf = r"..\model"
    # torch.save does not create missing directories
    os.makedirs(outf, exist_ok=True)

    unet = UNet(in_channels=3, out_channels=3)
    unet.to(device)
    optimizer = optim.Adam(list(unet.parameters()), lr=lr,
                           weight_decay=weight_decay)

    dataloaderT = torch.utils.data.DataLoader(trainDatas,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(0))
    # Log about four times per epoch; max() avoids a modulo by zero when
    # the loader has fewer than four batches.
    log_every = max(1, len(dataloaderT) // 4)
    dataSplit = None  # reserved rate of data

    for epoch in range(start_epoch, start_epoch + epochs):
        unet.train()
        for i, (x, y) in enumerate(dataloaderT):
            if dataSplit is not None:
                if i > len(dataloaderT) * dataSplit:
                    break
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            ypred = unet(x)
            # (input, target) order; MSE is symmetric so the value is equal
            loss = F.mse_loss(ypred, y)
            loss.backward()
            optimizer.step()

            if i % log_every == 0:
                print('[%d/%d][%d/%d]\tLoss: %.4f\t ' %
                      (epoch, start_epoch + epochs, i, len(dataloaderT),
                       loss.item()))

        state = {
            'model': unet.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        torch.save(state,
                   '%s/UnetS%d%sepoch%d.pth' % (outf, setLen, name, epoch))
## tensorboard 를 사용하기 위한 SummaryWriter 설정 writer_train = SummaryWriter(log_dir = os.path.join(log_dir, 'train')) writer_val = SummaryWriter(log_dir = os.path.join(log_dir, 'val')) ## 네트워크 학습시키기 st_epoch = 0 # TRAIN MODE if mode == 'train': if train_continue == "off": net, optim, st_epoch = load(ckpt_dir = ckpt_dir, net = net, optim = optim) for epoch in range(st_epoch + 1, num_epoch + 1): net.train() loss_arr = [] for batch, data in enumerate(loader_train, start = 1): # forward pass label = data['label'].to(device) input = data['input'].to(device) output = net(input) # backward pass optim.zero_grad() loss = fn_loss(output, label) loss.backward()
class Trainer(object):
    """Trainer for training and testing the model"""

    def __init__(self, data_loader, config):
        """Initialize configurations

        Args:
            data_loader: A single loader for modes ``'train'`` / ``'test'``,
                otherwise a ``(train_loader, test_loader)`` pair.
            config: Object carrying the attributes read below.
        """
        # model configuration
        self.in_dim = config.in_dim
        self.out_dim = config.out_dim
        self.num_filters = config.num_filters
        self.patch_size = config.patch_size

        # training configuration
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.weight_decay = config.weight_decay
        self.resume_iters = config.resume_iters
        self.mode = config.mode

        # miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device(
            'cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')

        # training result configuration
        self.log_dir = config.log_dir
        self.log_step = config.log_step
        self.model_save_dir = config.model_save_dir
        self.model_save_step = config.model_save_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # data loader
        if self.mode == 'train' or self.mode == 'test':
            self.data_loader = data_loader
        else:
            self.train_data_loader, self.test_data_loader = data_loader

    def build_model(self):
        """Create a model"""
        self.model = UNet(self.in_dim, self.out_dim, self.num_filters)
        self.model = self.model.float()
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr,
                                          [self.beta1, self.beta2],
                                          weight_decay=self.weight_decay)
        self.print_network(self.model, 'unet')
        self.model.to(self.device)

    def _load(self, checkpoint_path):
        """Load a checkpoint, remapping it to CPU when CUDA is unavailable."""
        if self.use_cuda:
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(checkpoint_path,
                                    map_location=lambda storage, loc: storage)
        return checkpoint

    def restore_model(self, resume_iters):
        """Restore the trained model"""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        model_path = os.path.join(self.model_save_dir,
                                  '{}-unet'.format(resume_iters) + '.ckpt')
        checkpoint = self._load(model_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def print_network(self, model, name):
        """Print out the network information"""
        num_params = sum(p.numel() for p in model.parameters())
        #print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def print_optimizer(self, opt, name):
        """Print out optimizer information"""
        print(opt)
        print(name)

    def build_tensorboard(self):
        """Build tensorboard for visualization"""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.optimizer.zero_grad()

    def train(self):
        """Train model

        Runs from ``resume_iters`` (or 0) to ``num_iters`` iterations,
        cycling through the data loader as often as needed, and writes a
        checkpoint every ``model_save_step`` iterations.
        """
        if self.mode != 'train_test':
            data_loader = self.data_loader
        else:
            data_loader = self.train_data_loader
        print("current dataset size: ", len(data_loader))
        data_iter = iter(data_loader)

        os.makedirs(self.model_save_dir, exist_ok=True)

        # start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print('Resuming ...')
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)
            self.print_optimizer(self.optimizer, 'optimizer')

        # print learning rate information
        lr = self.lr
        print('Current learning rates, g_lr: {}.'.format(lr))

        criterion = nn.BCEWithLogitsLoss()
        self.model = self.model.train()

        # start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # Fetch the next batch and restart the loader only once it is
            # exhausted. Catching just StopIteration keeps unpacking errors
            # from silently rebuilding the iterator on every step.
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                batch = next(data_iter)
            # Batches may carry extra fields (e.g. patch coordinates)
            in_data, label = batch[0], batch[1]

            in_data = in_data.float().to(self.device)
            # BCEWithLogitsLoss requires a floating point target
            label = label.float().to(self.device)

            # train the model
            y_out = self.model(in_data)
            output = criterion(y_out, label)
            self.reset_grad()
            output.backward()
            self.optimizer.step()

            # logging
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                log += ", {}: {:.4f}".format("loss", output.mean().item())
                print(log)

                if self.use_tensorboard:
                    self.logger.scalar_summary("loss",
                                               output.mean().item(), i + 1)

            # save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                path = os.path.join(self.model_save_dir,
                                    '{}-unet'.format(i + 1) + '.ckpt')
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                    }, path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

    def test(self):
        """Test model

        Thresholds the sigmoid output at 0.2, prints voxel accuracy and the
        all-background baseline, and saves the patch-averaged prediction
        volume to ``prediction.mat``.
        """
        if self.mode != 'train_test':
            data_loader = self.data_loader
        else:
            data_loader = self.test_data_loader
        print("current dataset size: ", len(data_loader))
        data_iter = iter(data_loader)

        # start testing on trained model
        if self.resume_iters and self.mode != 'train_test':
            print('Resuming ...')
            self.restore_model(self.resume_iters)

        # start testing.
        # NOTE(review): the volume size is hard-coded - confirm it matches
        # the data set.
        result, trace = np.zeros((78, 110, 24)), np.zeros((78, 110, 24))
        print('Start testing...')
        correct, total, bcorrect = 0, 0, 0
        radius = int(self.patch_size / 2)
        sigmoid = nn.Sigmoid()
        self.model = self.model.eval()

        with torch.no_grad():
            while True:
                # fetch batch data
                try:
                    data_in, label, i, j, k = next(data_iter)
                except StopIteration:
                    break
                data_in = data_in.float().to(self.device)
                label = label.float().to(self.device)

                # test the model
                y_hat = sigmoid(self.model(data_in))
                y_hat = y_hat.squeeze().detach().cpu().numpy()
                label = label.cpu().numpy().astype(int)

                y_hat_th = (y_hat > 0.2)
                label = (label > 0.5)
                test = (label == y_hat_th)
                correct += np.sum(test)
                btest = (label == 0)
                bcorrect += np.sum(btest)
                total += y_hat_th.size

                # len(i) rather than self.batch_size: the final batch can
                # hold fewer samples.
                for step in range(len(i)):
                    x, y, z, pred = i[step], j[step], k[step], np.squeeze(
                        y_hat_th[step, :, :, :])
                    result[x - radius:x + radius, y - radius:y + radius,
                           z - radius:z + radius] += pred
                    trace[x - radius:x + radius, y - radius:y + radius,
                          z - radius:z + radius] += np.ones(
                              (self.patch_size, self.patch_size,
                               self.patch_size))

        if total:
            print('Accuracy: %.3f%%' % (correct / total * 100))
            print('Baseline Accuracy: %.3f%%' % (bcorrect / total * 100))

        # Voxels never covered by a patch get divisor 1 instead of 0
        trace += (trace == 0)
        result = result / trace
        scipy.io.savemat('prediction.mat', {'result': result})

    def train_test(self):
        """Train and test model"""
        self.train()
        self.test()
# Load the weights of `model` from the configured directory and file name
model.load_state_dict(
    torch.load(os.path.join(args.load_model_dir, args.load_model_name)))
initial_epoch = findLastCheckpoint(
    save_dir=save_dir)  # load the last model in matconvnet style
# initial_epoch = 150
if initial_epoch > 0:
    print('resuming by loading epoch %03d' % initial_epoch)
    u_model.load_state_dict(
        torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
    # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
# `model` is put in evaluation mode while `u_model` is put in training mode
model.eval()
u_model.train()
criterion = nn.MSELoss()
if cuda:
    model = model.cuda()
    u_model = u_model.cuda()
# NOTE(review): the optimizer is built over model.parameters() although
# `model` is in eval mode and `u_model` is the one in train mode - confirm
# which network is meant to be optimised.
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90],
                        gamma=0.2)  # learning rates
for epoch in range(initial_epoch, n_epoch):
    scheduler.step(epoch)  # step to the learning rate in this epoch
    xs = dg.datagenerator(data_dir=args.train_data)
    # Scale the data to [0, 1] as float32
    xs = xs.astype('float32') / 255.0
    xs = torch.from_numpy(xs.transpose(
# Tensorboard train_writer = SummaryWriter(log_dir=TRAIN_LOG_DIR) val_writer = SummaryWriter(log_dir=VAL_LOG_DIR) # Training start_epoch = 0 # Load Checkpoint File if os.listdir(CKPT_DIR): net, optim, start_epoch = load_net(ckpt_dir=CKPT_DIR, net=net, optim=optim) else: print('* Training from scratch') num_epochs = cfg.NUM_EPOCHS for epoch in range(start_epoch + 1, num_epochs + 1): net.train() # Train Mode train_loss_arr = list() for batch_idx, data in enumerate(train_loader, 1): # Forward Propagation img = data['img'].to(device) label = data['label'].to(device) output = net(img) # Backward Propagation optim.zero_grad() loss = loss_fn(output, label) loss.backward()
def UNet(self):
    """Train a U-Net or run it on the test set, depending on ``self.mode``.

    In ``'train'`` mode the network is optimised with an MSE loss for
    ``self.num_epoch`` epochs, label/input/output images and the mean epoch
    loss are logged to TensorBoard, and a checkpoint is saved every 20
    epochs. In ``'test'`` mode the latest checkpoint is loaded and the
    thresholded predictions are written as PNG files below
    ``<self.result_dir>/png``.
    """
    if self.mode == 'train':
        transform = transforms.Compose(
            [Normalization(mean=0.5, std=0.5, mode='train'), ToTensor()])
        dataset_train = Dataset(mode=self.mode, data_dir=self.data_dir,
                                image_type=self.image_type,
                                transform=transform)
        loader_train = DataLoader(dataset_train, batch_size=self.batch_size,
                                  shuffle=True, num_workers=8)
        # Number of batches per epoch, used as the TensorBoard step stride
        num_batch_train = int(np.ceil(len(dataset_train) / self.batch_size))
    elif self.mode == 'test':
        transform = transforms.Compose(
            [Normalization(mean=0.5, std=0.5, mode='test'), ToTensor()])
        dataset_test = Dataset(mode=self.mode, data_dir=self.data_dir,
                               image_type=self.image_type,
                               transform=transform)
        loader_test = DataLoader(dataset_test, batch_size=self.batch_size,
                                 shuffle=False, num_workers=8)
        num_batch_test = int(np.ceil(len(dataset_test) / self.batch_size))

    def fn_tonumpy(x):
        # NCHW tensor -> NHWC numpy array
        return x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)

    def fn_denorm(x, mean, std):
        # Undo the (x - mean) / std normalisation
        return (x * std) + mean

    def fn_class(x):
        # Binarise at 0.5
        return 1.0 * (x > 0.5)

    net = UNet().to(self.device)
    criterion = torch.nn.MSELoss().to(self.device)
    optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
    writer_train = SummaryWriter(log_dir=os.path.join(self.log_dir, 'train'))

    if self.mode == 'train':
        if self.train_continue == "on":
            net, optimizer = load_model(ckpt_dir=self.ckpt_dir, net=net,
                                        optim=optimizer)

        for epoch in range(1, self.num_epoch + 1):
            net.train()
            loss_arr = []

            for batch, data in enumerate(loader_train, 1):
                # forward pass
                label = data['label'].to(self.device)
                inputs = data['input'].to(self.device)
                output = net(inputs)

                # backward pass
                optimizer.zero_grad()
                loss = criterion(output, label)
                loss.backward()
                optimizer.step()

                # accumulate the loss
                loss_arr += [loss.item()]

                # log the images to TensorBoard
                step = num_batch_train * (epoch - 1) + batch
                writer_train.add_image('label', fn_tonumpy(label), step,
                                       dataformats='NHWC')
                writer_train.add_image(
                    'input',
                    fn_tonumpy(fn_denorm(inputs, mean=0.5, std=0.5)),
                    step, dataformats='NHWC')
                writer_train.add_image('output',
                                       fn_tonumpy(fn_class(output)), step,
                                       dataformats='NHWC')

            writer_train.add_scalar('loss', np.mean(loss_arr), epoch)
            print("TRAIN: EPOCH %04d / %04d | LOSS %.4f" %
                  (epoch, self.num_epoch, np.mean(loss_arr)))

            if epoch % 20 == 0:
                # NOTE(review): epoch=0 means every save uses the same epoch
                # tag - confirm this overwrite is intended.
                save_model(ckpt_dir=self.ckpt_dir, net=net, optim=optimizer,
                           epoch=0)

        writer_train.close()

    # TEST MODE
    elif self.mode == 'test':
        net, optimizer = load_model(ckpt_dir=self.ckpt_dir, net=net,
                                    optim=optimizer)
        png_dir = os.path.join(self.result_dir, 'png')
        # plt.imsave does not create missing directories
        os.makedirs(png_dir, exist_ok=True)

        with torch.no_grad():
            net.eval()
            loss_arr = []
            image_id = 1
            batch = 0  # stays 0 when the test loader is empty

            for batch, data in enumerate(loader_test, 1):
                # forward pass
                inputs = data['input'].to(self.device)
                output = fn_tonumpy(fn_class(net(inputs)))

                for j in range(inputs.shape[0]):
                    # NOTE(review): output numbering jumps from 799 to 2350
                    # - presumably to match the data set's file numbering.
                    if image_id == 800:
                        image_id = 2350
                    print(image_id)
                    plt.imsave(os.path.join(png_dir, 'gt%06d.png' % image_id),
                               output[j].squeeze(), cmap='gray')
                    image_id += 1

        # No labels are available in test mode, so no loss is collected;
        # report NaN explicitly instead of taking the mean of an empty list.
        mean_loss = np.mean(loss_arr) if loss_arr else float('nan')
        print("AVERAGE TEST: BATCH %04d / %04d | LOSS %.4f" %
              (batch, num_batch_test, mean_loss))
        writer_train.close()
def main(args):
    """Evaluate a trained deblurring/interpolation U-Net on the test split.

    For every test batch the blurred input, the ground truth and the network
    output are written below ``./imgs/<batch index>/``, and running loss,
    PSNR and DSSIM are shown on a progress bar.

    Args:
        args: Parsed options; reads ``valid``, ``dataset_root``,
            ``sequence_length``, ``num_frame_blur``, ``test_batch_size``,
            ``decode_mode`` and ``checkpoint``.
    """
    writer = SummaryWriter(os.path.join('./logs'))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    test_valid = 'validation' if args.valid else 'test'
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length,
                            tau=args.num_frame_blur, delta=5,
                            transform=train_tf, return_path=True)
    test_loader = DataLoader(test_data, batch_size=args.test_batch_size,
                             shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    if args.checkpoint:
        print('Loading checkpoint...')
        store_dict = torch.load(args.checkpoint, map_location=device)
        # Accept both a wrapped {'state_dict': ...} file and a bare state dict
        model.load_state_dict(store_dict.get('state_dict', store_dict))
        print('Done.')

    model.to(device)
    model.train(False)

    # TODO DEFINE MORE CRITERIA
    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
        # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=False,
        #                              use_gpu=True if device == torch.device('cuda:0') else False)
    }
    # Loss weights; this function used criterion_w without defining it.
    # NOTE(review): values presumably match the training script - confirm.
    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}

    # Validation
    running_loss_test = 0.0
    psnr_test = 0.0
    dssim_test = 0.0

    tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)
    loss_tracker_test = {loss_fn: 0.0 for loss_fn in criterion}

    with torch.no_grad():
        model.eval()
        total_steps_test = 0.0

        for test_idx, data in enumerate(test_loader, 1):
            loss = 0.0
            blur_data, sharpe_data, sharp_names = data

            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]

            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(
                0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(
                    device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except (IndexError, RuntimeError):
                # squeeze() removed a needed size-1 axis; drop only axis 3
                sharpe_data = sharpe_data.squeeze(3).to(
                    device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # forward pass
            sharpe_out = model(blur_data).float()

            # compute losses
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn, loss_module in criterion.items():
                weight = criterion_w[loss_fn]
                if loss_fn == 'Perceptual':
                    loss_tmp = 0.0
                    for bidx in range(B):
                        loss_tmp += weight * loss_module(
                            sharpe_out[bidx].permute(1, 0, 2, 3),
                            sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                else:
                    loss_tmp = weight * loss_module(sharpe_out, sharpe_data)
                loss += loss_tmp
                loss_tracker_test[loss_fn] = (
                    loss_tracker_test.get(loss_fn, 0.0) + loss_tmp.item())

            # Create the directory the images are written to. The old check
            # tested a different path, so a later makedirs call raised
            # FileExistsError.
            out_dir = './imgs/{}'.format(test_idx)
            os.makedirs(out_dir, exist_ok=True)

            psnr_local = dssim_local = 0.0
            for sidx in range(S):
                for bidx in range(B):
                    blur_path = './imgs/{}/blur_input_{}.jpg'.format(
                        test_idx, sidx)
                    imsave(blur_data, blur_path, bidx, sidx)

                    sharp_path = './imgs/{}/sharpe_gt_{}{}.jpg'.format(
                        test_idx, sidx, sidx)
                    imsave(sharpe_data, sharp_path, bidx, sidx)

                    deblur_path = './imgs/{}/out_{}{}.jpg'.format(
                        test_idx, sidx, sidx)
                    imsave(sharpe_out, deblur_path, bidx, sidx)

                    if sidx > 0:
                        interp_path = './imgs/{}/out_{}{}.jpg'.format(
                            test_idx, sidx - 1, sidx)
                        imsave(sharpe_out, interp_path, bidx, sidx)
                        sharp_path = './imgs/{}/sharpe_gt_{}{}.jpg'.format(
                            test_idx, sidx - 1, sidx)
                        imsave(sharpe_data, sharp_path, bidx, sidx)

                    out_frame = sharpe_out[bidx, :, sidx, :, :].detach().cpu().numpy()
                    gt_frame = sharpe_data[bidx, :, sidx, :, :].cpu().numpy()
                    psnr_local = psnr(im_nm * out_frame, im_nm * gt_frame)
                    dssim_local = dssim(
                        np.moveaxis(im_nm * out_frame, 0, 2),
                        np.moveaxis(im_nm * gt_frame, 0, 2))
                    psnr_test += psnr_local
                    dssim_test += dssim_local

            # Empty marker file whose name records the last frame's metrics
            marker = './imgs/{0}/psnr-{1:.4f}-dssim-{2:.4f}.txt'.format(
                test_idx, psnr_local / (B), dssim_local / (B))
            with open(marker, 'w'):
                pass

            running_loss_test += loss.item()
            total_steps_test += B * S
            loss_str = ''
            for key in loss_tracker_test.keys():
                loss_str += ' {0} : {1:6.4f} '.format(
                    key, 1.0 * loss_tracker_test[key] / total_steps_test)

            # set display info
            tqdm_loader_test.set_description(
                ('\r[Test    ] loss: {0:6.4f} PSNR: {1:6.4f} SSIM: {2:6.4f} '.format(
                    running_loss_test / total_steps_test,
                    psnr_test / total_steps_test,
                    dssim_test / total_steps_test
                ) + loss_str))
            tqdm_loader_test.update(1)

    tqdm_loader_test.close()
    return None
def _prepare_batch(blur_data, sharpe_data, decode_mode, device):
    """Select the target frames and move one batch onto ``device``.

    Args:
        blur_data: Blurred input tensor as produced by the data loader.
        sharpe_data: Sharp ground-truth tensor as produced by the data loader.
        decode_mode: ``'interp'`` keeps every second frame starting at index 1
            along dim 2, ``'deblur'`` every second frame starting at index 0;
            any other value keeps all frames.
        device: Torch device both tensors are moved to.

    Returns:
        ``(blur_data, sharpe_data)``, each cropped to the first 352 entries
        of dim 3 and with its last two dims swapped.
        NOTE(review): dims 3/4 are presumably height/width -- confirm
        against ``BlurDataset``.
    """
    if decode_mode == 'interp':
        sharpe_data = sharpe_data[:, :, 1::2, :, :]
    elif decode_mode == 'deblur':
        sharpe_data = sharpe_data[:, :, 0::2, :, :]

    blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
    try:
        sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
    except (IndexError, RuntimeError):
        # A blanket squeeze() can drop too many singleton dims (the 5-index
        # slice or the permute then fails); fall back to squeezing dim 3 only.
        sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
    return blur_data, sharpe_data


def _weighted_losses(criterion, criterion_w, sharpe_out, sharpe_data):
    """Evaluate every configured loss on one batch.

    Args:
        criterion: Mapping of loss name to loss callable.
        criterion_w: Mapping of loss name to scalar weight.
        sharpe_out: Model prediction for the batch.
        sharpe_data: Ground truth for the batch.

    Returns:
        ``(total, per_loss)`` where ``total`` is the sum of the weighted
        losses and ``per_loss`` maps each loss name to its weighted value as
        a Python float.
    """
    total = 0.0
    per_loss = {}
    for loss_fn, loss_module in criterion.items():
        weight = criterion_w[loss_fn]
        if loss_fn == 'Perceptual':
            # The perceptual loss is applied sample by sample, after swapping
            # the first two dims of each sample, and summed over the batch.
            loss_tmp = 0.0
            for bidx in range(sharpe_out.shape[0]):
                loss_tmp += weight * loss_module(
                    sharpe_out[bidx].permute(1, 0, 2, 3),
                    sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
        else:
            loss_tmp = weight * loss_module(sharpe_out, sharpe_data)
        total += loss_tmp
        per_loss[loss_fn] = loss_tmp.item()
    return total, per_loss


def main(args):
    """Build the data loaders and model, then run the epoch loop.

    Each epoch evaluates the model on the validation/test split, logs to
    TensorBoard under ``./logs`` and keeps a single checkpoint in
    ``args.checkpoint_dir``: the one with the best validation PSNR so far.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``checkpoint_dir``, ``valid``, ``dataset_root``,
            ``sequence_length``, ``num_frame_blur``, ``train_batch_size``,
            ``test_batch_size``, ``decode_mode``, ``checkpoint``,
            ``train_continue``, ``init_learning_rate``, ``milestones``,
            ``epochs`` and ``progress_iter``.

    Returns:
        None.

    Raises:
        ValueError: If ``args.train_continue`` is set without
            ``args.checkpoint``, or if the validation loader is empty.
    """
    writer = SummaryWriter(os.path.join('./logs'))
    # torch.backends.cudnn.benchmark = True
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    # (channel-wise mean calculated on the adobe240-fps training dataset:
    # [0.429, 0.431, 0.397]; no normalisation is applied at the moment).
    train_tf = None
    test_tf = None

    test_valid = 'validation' if args.valid else 'test'
    train_data = BlurDataset(os.path.join(args.dataset_root, 'train'),
                             seq_len=args.sequence_length,
                             tau=args.num_frame_blur,
                             delta=5,
                             transform=train_tf)
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length,
                            tau=args.num_frame_blur,
                            delta=5,
                            transform=test_tf)

    train_loader = DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    # Start from a fresh history; it is only overwritten with values that a
    # resumed checkpoint really contains (checkpoints written below do not
    # store these lists, so indexing them directly raised KeyError).
    store_dict = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}
    if args.checkpoint:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        loaded = torch.load(args.checkpoint, map_location=device)
        # Accept both a full checkpoint dict and a bare state dict.
        model.load_state_dict(loaded['state_dict'] if 'state_dict' in loaded else loaded)
        if args.train_continue:
            for key in store_dict:
                if key in loaded:
                    store_dict[key] = loaded[key]
    elif args.train_continue:
        raise ValueError('--train_continue requires --checkpoint')

    model.to(device)
    model.train(True)
    # model = nn.DataParallel(model)

    # TODO DEFINE MORE CRITERIA
    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
        # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=True,
        #                              use_gpu=True if device == torch.device('cuda:0') else False)
    }
    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}

    # Define optimizers
    optimizer = optim.Adam(model.parameters(), lr=args.init_learning_rate)

    # Define lr scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)

    # NOTE(review): these per-epoch history lists are extended below but are
    # never filled or written into a checkpoint.
    cLoss = store_dict['loss']
    valLoss = store_dict['valLoss']
    valPSNR = store_dict['valPSNR']

    psnr_old = 0.0
    dssim_old = 0.0
    # Path of the checkpoint currently on disk, so it can be replaced safely.
    best_ckpt_path = None

    try:
        # NOTE(review): the upper bound is 10 * args.epochs (exclusive), i.e.
        # 10 * epochs - 1 iterations -- confirm this is intended.
        for epoch in range(1, 10 * args.epochs):
            cLoss.append([])
            valLoss.append([])
            valPSNR.append([])

            # Train-side accumulators. total_steps starts at 0.01 so the
            # averages below are defined even though nothing is accumulated.
            running_loss = 0
            psnr_ = 0.0
            dssim_ = 0.0
            loss_tracker = {loss_fn: 0.0 for loss_fn in criterion}
            total_steps = 0.01

            model.train(True)
            # NOTE(review): the training pass over train_loader was disabled
            # in the original (wrapped in a string literal) and is left out
            # here. It prepared each batch like the validation loop, called
            # optimizer.zero_grad(), ran the forward pass, summed the weighted
            # losses, called loss.backward() and optimizer.step(), and
            # accumulated PSNR per frame. Until it is restored, all train
            # metrics logged and checkpointed below stay at their initial
            # values.

            # Validation
            running_loss_test = 0.0
            psnr_test = 0.0
            dssim_test = 0.0
            total_steps_test = 0.0
            loss_tracker_test = {loss_fn: 0.0 for loss_fn in criterion}
            tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)

            model.eval()
            with torch.no_grad():
                for test_idx, data in enumerate(test_loader, 1):
                    blur_data, sharpe_data = data
                    blur_data, sharpe_data = _prepare_batch(
                        blur_data, sharpe_data, args.decode_mode, device)

                    # forward pass
                    sharpe_out = model(blur_data).permute(0, 2, 1, 3, 4)
                    B, _, S, _, _ = sharpe_out.shape

                    # compute losses
                    loss, batch_losses = _weighted_losses(
                        criterion, criterion_w, sharpe_out, sharpe_data)
                    for loss_fn, value in batch_losses.items():
                        loss_tracker_test[loss_fn] += value

                    # statistics
                    out_np = sharpe_out.detach().cpu().numpy()
                    gt_np = sharpe_data.cpu().numpy()
                    for sidx in range(S):
                        for bidx in range(B):
                            psnr_test += psnr(out_np[bidx, :, sidx, :, :],
                                              gt_np[bidx, :, sidx, :, :])
                            dssim_test += dssim(
                                np.moveaxis(out_np[bidx, :, sidx, :, :], 0, 2),
                                np.moveaxis(gt_np[bidx, :, sidx, :, :], 0, 2))

                    running_loss_test += loss.item()
                    total_steps_test += B * S

                    # Log only after this batch has been accumulated: logging
                    # first divided by total_steps_test while it was still 0.
                    if (test_idx % args.progress_iter) == args.progress_iter - 1:
                        itr = test_idx + epoch * len(test_loader)
                        writer.add_scalars(
                            'Loss', {
                                'trainLoss': running_loss / total_steps,
                                'validationLoss': running_loss_test / total_steps_test
                            }, itr)
                        writer.add_scalar('Train PSNR', psnr_ / total_steps, itr)
                        writer.add_scalar('Test PSNR', psnr_test / total_steps_test, itr)

                    loss_str = ''.join(
                        ' {0} : {1:6.4f} '.format(key, 1.0 * value / total_steps_test)
                        for key, value in loss_tracker_test.items())

                    # set display info
                    tqdm_loader_test.set_description((
                        '\r[Test ] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '
                        .format(epoch, running_loss_test / total_steps_test,
                                psnr_test / total_steps_test,
                                dssim_test / total_steps_test) + loss_str))
                    tqdm_loader_test.update(1)
            tqdm_loader_test.close()

            if not total_steps_test:
                raise ValueError('validation loader produced no batches')

            # save model: keep only the checkpoint with the best test PSNR
            epoch_psnr = psnr_test / total_steps_test
            if psnr_old < epoch_psnr:
                if best_ckpt_path is not None and os.path.exists(best_ckpt_path):
                    os.remove(best_ckpt_path)

                epoch_old = epoch
                psnr_old = epoch_psnr
                dssim_old = dssim_test / total_steps_test
                best_ckpt_path = os.path.join(
                    args.checkpoint_dir,
                    'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                        epoch_old,
                        str(round(psnr_old, 4)).replace('.', 'pt'),
                        str(round(dssim_old, 4)).replace('.', 'pt')))

                checkpoint_dict = {
                    'epoch': epoch_old,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'train_psnr': psnr_ / total_steps,
                    'train_dssim': dssim_ / total_steps,
                    'train_mse': loss_tracker['MSE'] / total_steps,
                    'train_l1': loss_tracker['L1'] / total_steps,
                    # 'train_percp': loss_tracker['Perceptual'] / total_steps,
                    'test_psnr': psnr_old,
                    'test_dssim': dssim_old,
                    'test_mse': loss_tracker_test['MSE'] / total_steps_test,
                    'test_l1': loss_tracker_test['L1'] / total_steps_test,
                    # 'test_percp': loss_tracker_test['Perceptual'] / total_steps_test,
                }
                torch.save(checkpoint_dict, best_ckpt_path)

            # Advance the schedule at the end of the epoch: stepping before
            # the optimizer updates skips the first learning-rate value.
            scheduler.step()
    finally:
        # Flush pending TensorBoard events even if an epoch fails.
        writer.close()

    return None
def train():
    """Train the model, or only evaluate it when ``args.eval`` is set.

    Reads its configuration from the module-level ``args`` namespace. After
    every epoch the model is evaluated on the dev set, the F1 and exact-match
    histories are pickled next to ``args.model_dir``, and the weights are
    saved to ``args.model_dir`` whenever F1 improves. The learning rate is
    multiplied by ``args.decay`` every ``args.decay_period`` epochs.

    Raises:
        SystemExit: After the evaluation-only run when ``args.eval`` is set.
    """
    os.makedirs("train_model/", exist_ok=True)
    os.makedirs("result/", exist_ok=True)

    train_data, dev_data, _, _, _, opts = load_data(vars(args))
    model = UNet(opts)
    if args.use_cuda:
        model = model.cuda()

    # One answers file per model name. The training branches previously built
    # "result/<name>/.answers" through os.path.join, a path inside a
    # directory that is never created and that disagreed with the
    # evaluation-only branch.
    answer_file = "result/" + args.model_dir.split("/")[-1] + ".answers"
    # Without CUDA, remap GPU-saved weights onto the CPU while loading.
    map_location = None if args.use_cuda else "cpu"

    def evaluate(batches):
        """Run ``model.Evaluate`` on ``batches`` with the shared file paths."""
        return model.Evaluate(
            batches,
            os.path.join(args.prepro_dir, "dev_eval.json"),
            answer_file=answer_file,
            drop_file=os.path.join(args.prepro_dir, "drop.json"),
            dev=args.dev_file,
        )

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    if args.eval:
        print("load model...")
        model.load_state_dict(torch.load(args.model_dir, map_location=map_location))
        model.eval()
        evaluate(dev_batches)
        sys.exit()

    if args.load_model:
        print("load model...")
        model.load_state_dict(torch.load(args.model_dir, map_location=map_location))
        model.eval()
        _, F1 = evaluate(dev_batches)
        best_score = F1
        # NOTE: these pickles are files this script wrote itself on an
        # earlier run; never point model_dir at untrusted data.
        with open(args.model_dir + "_f1_scores.pkl", "rb") as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + "_em_scores.pkl", "rb") as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # Materialise the trainable parameters as a list. A filter() iterator is
    # exhausted by the optimizer constructor, which left clip_grad_norm_ with
    # nothing to clip.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)
    lrate = args.lrate
    total_size = len(train_data) // args.batch_size

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts["grad_clipping"])
            optimizer.step()
            # NOTE(review): project-specific hook invoked after every update;
            # its effect is defined on the model class, not here.
            model.reset_parameters()

            if i % 100 == 0:
                print(
                    "Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f"
                    % (epoch, i, total_size, model.train_loss.value, lrate, best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = evaluate(dev_batches)
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + "_f1_scores.pkl", "wb") as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + "_em_scores.pkl", "wb") as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print("saving %s ..." % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)

        # Step decay of the learning rate every decay_period epochs.
        if epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group["lr"] = lrate