# Variant 1: multimodal pipeline -- passes modality directories, MIP choices,
# and class counts to the generators, optionally loads custom weights, and
# derives steps_per_epoch/validation_steps from the list sizes.
import numpy as np


def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model, loss_weights = compile_model(args=args, train_list=train_list,
                                        net_input_shape=net_input_shape,
                                        uncomp_model=u_model)

    # Load pre-trained weights, if a path was given
    if args.custom_weights_path != '':
        try:
            model.load_weights(args.custom_weights_path)
        except Exception as e:
            print(e)
            print('!!! Failed to load weights file. Training without pre-trained weights. !!!')

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Train the network
    history = model.fit_generator(
        generate_train_batches(root_path=args.data_root_dir, train_list=train_list,
                               net_shape=net_input_shape, mod_dirs=args.modality_dir_list,
                               exp_name=args.exp_name, net=args.net,
                               MIP_choices=args.MIP_choices, n_class=args.num_classes,
                               batchSize=args.batch_size, numSlices=args.slices,
                               subSampAmt=args.subsamp, stride=args.stride,
                               shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=int(np.ceil(len(train_list) / args.batch_size * 12)),  # 12 = avg. number of loops in the train generator
        validation_data=generate_val_batches(root_path=args.data_root_dir, val_list=val_list,
                                             net_shape=net_input_shape, mod_dirs=args.modality_dir_list,
                                             exp_name=args.exp_name, net=args.net,
                                             MIP_choices=args.MIP_choices, n_class=args.num_classes,
                                             batchSize=args.batch_size, numSlices=args.slices,
                                             subSampAmt=0, stride=args.stride,
                                             shuff=args.shuffle_data),
        validation_steps=int(np.ceil(len(val_list) / args.batch_size)),
        epochs=args.epochs,
        class_weight=loss_weights,
        callbacks=callbacks,
        verbose=args.verbose)

    # Plot the training data collected
    plot_training(history, args.net, args.num_classes, args.output_dir,
                  args.output_name, args.time)
# Variant 2: manual epoch loop -- calls fit_generator once per epoch (epochs
# defaults to 1), accumulates per-epoch metrics in numpy vectors, and dumps
# them to a text file after every epoch.
import numpy as np
from os.path import join


def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    num_epoch = 400
    model = compile_model(args=args, net_input_shape=net_input_shape,
                          uncomp_model=u_model)

    loss_vec = np.zeros((num_epoch, 1))
    dice_hard_vec = np.zeros((num_epoch, 1))
    val_loss_vec = np.zeros((num_epoch, 1))
    val_dice_hard_vec = np.zeros((num_epoch, 1))
    val_out_seg_loss_vec = np.zeros((num_epoch, 1))
    out_seg_loss_vec = np.zeros((num_epoch, 1))
    val_out_recon_loss_vec = np.zeros((num_epoch, 1))
    out_recon_loss_vec = np.zeros((num_epoch, 1))

    for i in range(num_epoch):
        print(i)
        # Set the callbacks (rebuilt each epoch so filenames can include i)
        callbacks = get_callbacks(args, i)

        # Build fresh generators for this epoch
        train_batches = generate_train_batches(args.data_root_dir, train_list,
                                               net_input_shape, net=args.net,
                                               image_shape=net_input_shape,
                                               batchSize=args.batch_size,
                                               numSlices=args.slices,
                                               subSampAmt=args.subsamp,
                                               stride=args.stride,
                                               shuff=args.shuffle_data,
                                               aug_data=args.aug_data)
        val_data = generate_val_batches(args.data_root_dir, val_list,
                                        net_input_shape, net=args.net,
                                        batchSize=args.batch_size,
                                        numSlices=args.slices, subSampAmt=0,
                                        stride=20, shuff=args.shuffle_data)

        # Train the network for one epoch (this fork's generator wrappers
        # expose their underlying iterator as .it)
        history = model.fit_generator(train_batches.it,
                                      max_queue_size=40, workers=4,
                                      use_multiprocessing=False,
                                      steps_per_epoch=247,
                                      validation_data=val_data.it,
                                      validation_steps=200,
                                      callbacks=callbacks, verbose=1)

        # Record the metrics collected this epoch
        if args.net == 'segcapsr3':
            print(history.history.keys())
            loss_vec[i] = history.history['loss'][0]
            dice_hard_vec[i] = history.history['out_seg_dice_hard'][0]
            val_loss_vec[i] = history.history['val_loss'][0]
            val_dice_hard_vec[i] = history.history['val_out_seg_dice_hard'][0]
            val_out_seg_loss_vec[i] = history.history['val_out_seg_loss'][0]
            out_seg_loss_vec[i] = history.history['out_seg_loss'][0]
            val_out_recon_loss_vec[i] = history.history['val_out_recon_loss'][0]
            out_recon_loss_vec[i] = history.history['out_recon_loss'][0]
            with open(join(args.output_dir, '_errors_' + str(i) + '.txt'), 'w+') as f:
                f.writelines([str(loss_vec), str(dice_hard_vec),
                              str(out_seg_loss_vec), str(out_recon_loss_vec),
                              str(val_loss_vec), str(val_dice_hard_vec),
                              str(val_out_seg_loss_vec), str(val_out_recon_loss_vec)])
        else:
            loss_vec[i] = history.history['loss'][0]
            dice_hard_vec[i] = history.history['dice_hard'][0]
            val_loss_vec[i] = history.history['val_loss'][0]
            val_dice_hard_vec[i] = history.history['val_dice_hard'][0]
            with open(join(args.output_dir, '_errors_' + str(i) + '.txt'), 'w+') as f:
                f.writelines([str(loss_vec), str(dice_hard_vec),
                              str(val_loss_vec), str(val_dice_hard_vec)])

    # Plot the training data collected
    plot_training(loss_vec, dice_hard_vec, val_loss_vec, val_dice_hard_vec,
                  args, num_epoch)
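# Aside (sketch, not from any of the forks above): the manual metric vectors
# and per-epoch text files in Variant 2 re-implement what Keras's built-in
# CSVLogger callback already provides -- one appended row of every history
# metric per epoch. A minimal self-contained illustration with a toy model and
# toy data (all names here are placeholders, not the project's real code):
import numpy as np
import tensorflow as tf

toy_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
toy_model.compile(optimizer='adam', loss='mse')
x_toy = np.random.rand(32, 4).astype('float32')
y_toy = np.random.rand(32, 1).astype('float32')
toy_model.fit(x_toy, y_toy, epochs=5,
              callbacks=[tf.keras.callbacks.CSVLogger('errors.csv', append=True)])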
# Variant 3: tf.keras 2.x style -- uses model.fit with generators and
# optionally resumes from an existing weights file before training.
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape,
                          uncomp_model=u_model)

    weights_path = args.weights_path
    if args.retrain == 1:
        print('\nRetraining model from weights_path=%s' % weights_path)
        model.load_weights(weights_path)

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Train the network
    history = model.fit(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape,
                               net=args.net, batchSize=args.batch_size,
                               numSlices=args.slices, subSampAmt=args.subsamp,
                               stride=args.stride, shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=args.steps,
        validation_data=generate_val_batches(args.data_root_dir, val_list,
                                             net_input_shape, net=args.net,
                                             batchSize=args.batch_size,
                                             numSlices=args.slices, subSampAmt=0,
                                             stride=20, shuff=args.shuffle_data),
        validation_steps=250,  # the validation stride (20) is set larger to see more of the data
        epochs=args.epochs,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
# Variant 4: like Variant 5 but builds the generators up front and prints the
# train generator object before training; steps and epochs are hard-coded.
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape,
                          uncomp_model=u_model)

    # Set the callbacks
    callbacks = get_callbacks(args)
    # print(callbacks)

    train_batches = generate_train_batches(args.data_root_dir, train_list,
                                           net_input_shape, net=args.net,
                                           batchSize=args.batch_size,
                                           numSlices=args.slices,
                                           subSampAmt=args.subsamp,
                                           stride=args.stride,
                                           shuff=args.shuffle_data,
                                           aug_data=args.aug_data)
    val_batches = generate_val_batches(args.data_root_dir, val_list,
                                       net_input_shape, net=args.net,
                                       batchSize=args.batch_size,
                                       numSlices=args.slices, subSampAmt=0,
                                       stride=20, shuff=args.shuffle_data)
    print("train_batches..", train_batches)

    # Train the network
    history = model.fit_generator(
        train_batches,
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=val_batches,
        validation_steps=500,  # the validation stride (20) is set larger to see more of the data
        epochs=200,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
# Variant 5: the baseline fit_generator training loop with hard-coded
# steps_per_epoch and epochs.
def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape,
                          uncomp_model=u_model)

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Train the network
    history = model.fit_generator(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape,
                               net=args.net, batchSize=args.batch_size,
                               numSlices=args.slices, subSampAmt=args.subsamp,
                               stride=args.stride, shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=generate_val_batches(args.data_root_dir, val_list,
                                             net_input_shape, net=args.net,
                                             batchSize=args.batch_size,
                                             numSlices=args.slices, subSampAmt=0,
                                             stride=20, shuff=args.shuffle_data),
        validation_steps=500,  # the validation stride (20) is set larger to see more of the data
        epochs=200,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
# Variant 6: tf.keras 2.x style with steps derived from the dataset sizes;
# the slice/stride generator options are dropped here.
from math import ceil


def train(args, train_list, val_list, u_model, net_input_shape):
    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape,
                          uncomp_model=u_model)

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Train the network
    history = model.fit(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape,
                               net=args.net, batch_size=args.batch_size,
                               shuff=args.shuffle_data, aug_data=args.aug_data),
        max_queue_size=40, workers=0, use_multiprocessing=False,
        steps_per_epoch=ceil(len(train_list) / args.batch_size),
        validation_data=generate_val_batches(args.data_root_dir, val_list,
                                             net_input_shape, net=args.net,
                                             batch_size=args.batch_size,
                                             shuff=args.shuffle_data),
        validation_steps=ceil(len(val_list) / args.batch_size),
        epochs=50,
        callbacks=callbacks,
        verbose=1)

    # Plot the training data collected
    plot_training(history, args)
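# Aside (sketch): Variants 3 and 6 call model.fit with a Python generator,
# while the others use the older model.fit_generator. In tf.keras 2.x,
# fit_generator is deprecated and fit accepts generators directly with the
# same steps_per_epoch/validation_steps arguments, so the migration is
# essentially a rename. Toy illustration (the model and generator below are
# placeholders, not the project's real code):
import numpy as np
import tensorflow as tf

def toy_batches(batch_size=4):
    # Endless (inputs, targets) generator, mimicking the contract of the
    # generate_train_batches helpers above.
    while True:
        xb = np.random.rand(batch_size, 8).astype('float32')
        yb = np.random.randint(0, 2, size=(batch_size, 1)).astype('float32')
        yield xb, yb

demo_model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
demo_model.compile(optimizer='adam', loss='binary_crossentropy')
# Deprecated: demo_model.fit_generator(toy_batches(), steps_per_epoch=10, epochs=2)
demo_model.fit(toy_batches(), steps_per_epoch=10, epochs=2)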
# Variant 7: PyTorch rewrite -- the original Keras body is kept below as a
# docstring for reference; the live code runs a manual Adam training loop with
# a summed segmentation loss plus an MSE reconstruction loss.
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.utils import save_image


def train(args, model, train_list, net_input_shape):
    """
    Old Keras implementation, kept for reference:

    # Compile the loaded model
    model = compile_model(args=args, net_input_shape=net_input_shape,
                          uncomp_model=u_model)

    # Set the callbacks
    callbacks = get_callbacks(args)

    # Training the network
    history = model.fit_generator(
        generate_train_batches(args.data_root_dir, train_list, net_input_shape,
                               net=args.net, batchSize=args.batch_size,
                               numSlices=args.slices, subSampAmt=args.subsamp,
                               stride=args.stride, shuff=args.shuffle_data,
                               aug_data=args.aug_data),
        max_queue_size=40, workers=4, use_multiprocessing=False,
        steps_per_epoch=10000,
        validation_data=generate_val_batches(args.data_root_dir, val_list,
                                             net_input_shape, net=args.net,
                                             batchSize=args.batch_size,
                                             numSlices=args.slices, subSampAmt=0,
                                             stride=20, shuff=args.shuffle_data),
        validation_steps=500,  # Set validation stride larger to see more of the data.
        epochs=200,
        callbacks=callbacks,
        verbose=1)
    """
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr,
                                 betas=(0.99, 0.999))
    # Note: this scheduler is constructed but never stepped in the loop below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                           factor=0.05,
                                                           patience=5,
                                                           verbose=True)
    loss, loss_weighting = get_loss(root=args.data_root_dir,
                                    split=args.split_num, net=args.net,
                                    recon_wei=args.recon_wei, choice=args.loss)
    recon_loss = nn.MSELoss(reduction='sum')

    fit_generator = generate_train_batches(args.data_root_dir, train_list,
                                           net_input_shape, net=args.net,
                                           batchSize=args.batch_size,
                                           numSlices=args.slices,
                                           subSampAmt=args.subsamp,
                                           stride=args.stride,
                                           shuff=args.shuffle_data,
                                           aug_data=args.aug_data)
    factor = 1.
    for i, batch in enumerate(fit_generator):
        # Inputs and targets arrive as NHWC numpy arrays, e.g. (1, 512, 512, 1);
        # permute to NCHW for PyTorch.
        x = torch.from_numpy(batch[0][0]).float().permute(0, 3, 1, 2)
        x = Variable(x).cuda()
        x1 = torch.from_numpy(batch[0][1]).float().permute(0, 3, 1, 2)
        x1 = Variable(x1).cuda()
        y = torch.from_numpy(batch[1][0]).float().permute(0, 3, 1, 2)
        y = y.cuda()
        label_recon = torch.from_numpy(batch[1][1]).float().permute(0, 3, 1, 2)
        label_recon = label_recon.cuda()

        optimizer.zero_grad()
        out_seg, out_recon = model(x, x1)
        loss_a = loss['out_seg'](y, out_seg)

        # Periodically dump inputs, targets, predictions, and the per-pixel
        # segmentation loss map as images.
        if (i % 100) == 0:
            save_image(x, 'x.jpg')
            save_image(y, 'y.jpg')
            save_image(label_recon, 'label_recon.jpg')
            save_image(out_seg, 'out_seg.jpg')
            save_image(loss_a, 'loss_a.jpg')
            save_image(out_recon, 'out_recon.jpg')

        loss_a = loss_a.sum()
        loss_b = recon_loss(out_recon, label_recon)
        total_loss = loss_a + factor * loss_b
        total_loss.backward()
        optimizer.step()
        print("batch:", i, "loss a:", loss_a.data.item(),
              "loss b:", loss_b.data.item(),
              "check:", out_recon.sum().data.item())
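# Aside (sketch): Variant 7 constructs a ReduceLROnPlateau scheduler but never
# calls scheduler.step(), so the learning rate stays fixed. The usual PyTorch
# pattern is to step it once per epoch on the monitored metric (mode='max'
# here, so something like a validation Dice score). Everything below is a
# placeholder illustration, not the project's real training or validation code:
import torch

stub_model = torch.nn.Linear(8, 1)  # stand-in for the real network
stub_optim = torch.optim.Adam(stub_model.parameters(), lr=1e-4, betas=(0.99, 0.999))
stub_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(stub_optim, mode='max',
                                                        factor=0.05, patience=5)
for epoch in range(10):
    # ... one epoch of training would go here ...
    val_dice = 0.5 + 0.01 * epoch   # placeholder for a real validation metric
    stub_sched.step(val_dice)       # reduces the LR when val_dice plateaus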