def train(args, arguments):
    batch_time = Utilities.AverageMeter()
    losses = Utilities.AverageMeter()

    # switch to train mode
    arguments['model'].train()

    end = time.time()

    train_loader_len = int(math.ceil(arguments['TADL'].shard_size / args.batch_size))
    i = 0
    arguments['TADL'].reset_avail_winds(arguments['epoch'])
    while i * arguments['TADL'].batch_size < arguments['TADL'].shard_size:
        # get the noisy inputs and the labels
        _, inputs, _, _, labels = arguments['TADL'].get_batch(descart_empty_windows=False)

        mean = torch.mean(inputs, 1, True)
        inputs = inputs - mean

        # zero the parameter gradients
        arguments['optimizer'].zero_grad()

        # forward + backward + optimize
        inputs = inputs.unsqueeze(1)
        outputs = arguments['model'](inputs)

        # Compute Huber loss
        loss = F.smooth_l1_loss(outputs[:, 0], labels[:, 0])

        # Adjust learning rate
        #Model_Util.learning_rate_schedule(args, arguments)

        # compute gradient and do SGD step
        loss.backward()
        arguments['optimizer'].step()

        #if args.test:
            #if i > 10:
                #break

        if i % args.print_freq == 0 and i != 0:
            # Every print_freq iterations, check the loss and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Average loss across processes for logging
            if args.distributed:
                reduced_loss = Utilities.reduce_tensor(loss.data, args.world_size)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(Utilities.to_python_float(reduced_loss), args.batch_size)

            if not args.cpu:
                torch.cuda.synchronize()

            batch_time.update((time.time() - end) / args.print_freq, args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})'.format(
                          arguments['epoch'], i, train_loader_len,
                          args.world_size * args.batch_size / batch_time.val,
                          args.world_size * args.batch_size / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses))

        i += 1

    arguments['loss_history'].append(losses.avg)

    return batch_time.sum, batch_time.avg
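# The loop above relies on a few helpers from the `Utilities` module defined
# elsewhere in the repository. A minimal sketch of the behavior the calls
# assume (illustrative only -- the real implementations may differ):

class _AverageMeterSketch:
    """Tracks the latest value plus a running sum, count, and average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val            # most recent value
        self.sum += val * n       # weighted running sum
        self.count += n
        self.avg = self.sum / self.count


def _reduce_tensor_sketch(tensor, world_size):
    # Average a logged metric across all distributed workers (an allreduce).
    import torch
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= world_size
    return rt


def _to_python_float_sketch(t):
    # Pulling a scalar tensor back to Python forces a host<->device sync.
    return t.item()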
def validate(args, arguments):
    miscounted_pulses = torch.tensor(0)
    number_of_pulses = torch.tensor(0)
    batch_time = Utilities.AverageMeter()
    average_counter_error = Utilities.AverageMeter()

    # switch to evaluate mode
    arguments['model'].eval()

    end = time.time()

    val_loader_len = int(math.ceil(arguments['VADL'].shard_size / args.batch_size))
    i = 0
    arguments['VADL'].reset_avail_winds(arguments['epoch'])
    while i * arguments['VADL'].batch_size < arguments['VADL'].shard_size:
        # bring a new batch
        times, noisy_signals, clean_signals, _, labels = arguments['VADL'].get_batch(descart_empty_windows=False)

        mean = torch.mean(noisy_signals, 1, True)
        noisy_signals = noisy_signals - mean

        with torch.no_grad():
            noisy_signals = noisy_signals.unsqueeze(1)
            outputs = arguments['model'](noisy_signals)
            noisy_signals = noisy_signals.squeeze(1)

            # symmetric relative error per window: |true - pred| divided by
            # the mean of their magnitudes
            denominator = (abs(labels[:, 0].to('cpu')) + abs(outputs[:, 0].data.to('cpu'))) / 2
            errors = abs(labels[:, 0].to('cpu') - outputs[:, 0].data.to('cpu')) / denominator
            errors = torch.mean(errors, dim=0)

            counter_error = errors

            if args.evaluate:
                miscounted_pulses += abs(round(labels[:, 0].to('cpu').sum(dim=0).item())
                                         - round(outputs[:, 0].data.to('cpu').sum(dim=0).item()))
                number_of_pulses += round(labels[:, 0].to('cpu').sum(dim=0).item())

        if args.distributed:
            reduced_counter_error = Utilities.reduce_tensor(counter_error.data, args.world_size)
        else:
            reduced_counter_error = counter_error.data

        average_counter_error.update(Utilities.to_python_float(reduced_counter_error), args.batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #if args.test:
            #if i > 10:
                #break

        if args.local_rank == 0 and i % args.print_freq == 0 and i != 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Counter Error {dur_error.val:.4f} ({dur_error.avg:.4f})'.format(
                      i, val_loader_len,
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time,
                      dur_error=average_counter_error))

        i += 1

    if not args.evaluate:
        arguments['counter_error_history'].append(average_counter_error.avg)

    if args.evaluate:
        if args.distributed:
            reduced_miscounted_pulses = Utilities.reduce_tensor_sum(miscounted_pulses.data)
            reduced_number_of_pulses = Utilities.reduce_tensor_sum(number_of_pulses.data)
        else:
            reduced_miscounted_pulses = miscounted_pulses.data
            reduced_number_of_pulses = number_of_pulses.data

        print('##We have {} miscounted pulses from a total of {} pulses'.format(
            reduced_miscounted_pulses, reduced_number_of_pulses))

    return average_counter_error.avg
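# The per-window metric above is a symmetric relative error: the absolute
# difference between true and predicted counts divided by the mean of their
# magnitudes. A tiny worked example of that formula (illustrative numbers):

def _counter_error_example():
    import torch

    true_counts = torch.tensor([4.0, 10.0])
    predicted_counts = torch.tensor([5.0, 9.0])

    denominator = (abs(true_counts) + abs(predicted_counts)) / 2
    errors = abs(true_counts - predicted_counts) / denominator
    # window 1: 1.0 / 4.5 ~= 0.222, window 2: 1.0 / 9.5 ~= 0.105
    return torch.mean(errors, dim=0)  # ~= 0.164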
def main():
    global best_error, args

    best_error = math.inf
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    if not len(args.counter):
        raise Exception("error: No path to counter model provided")

    if not len(args.predictor):
        raise Exception("error: No path to predictor model provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        if not args.cpu:
            torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='gloo', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    # Set the device
    device = torch.device('cpu' if args.cpu else 'cuda:' + str(args.gpu))

    # create model_1 (translocation counter)
    if args.test:
        args.arch_1 = 'ResNet10'

    if args.local_rank == 0:
        print("=> creating model_1 '{}'".format(args.arch_1))

    if args.arch_1 == 'ResNet18':
        model_1 = rn.ResNet18_Counter()
    elif args.arch_1 == 'ResNet34':
        model_1 = rn.ResNet34_Counter()
    elif args.arch_1 == 'ResNet50':
        model_1 = rn.ResNet50_Counter()
    elif args.arch_1 == 'ResNet101':
        model_1 = rn.ResNet101_Counter()
    elif args.arch_1 == 'ResNet152':
        model_1 = rn.ResNet152_Counter()
    elif args.arch_1 == 'ResNet10':
        model_1 = rn.ResNet10_Counter()
    else:
        print("Unrecognized {} for translocations counter architecture".format(args.arch_1))

    # create model_2 (translocation feature predictor)
    if args.test:
        args.arch_2 = 'ResNet10'

    if args.local_rank == 0:
        print("=> creating model_2 '{}'".format(args.arch_2))

    if args.arch_2 == 'ResNet18':
        model_2 = rn.ResNet18_Custom()
    elif args.arch_2 == 'ResNet34':
        model_2 = rn.ResNet34_Custom()
    elif args.arch_2 == 'ResNet50':
        model_2 = rn.ResNet50_Custom()
    elif args.arch_2 == 'ResNet101':
        model_2 = rn.ResNet101_Custom()
    elif args.arch_2 == 'ResNet152':
        model_2 = rn.ResNet152_Custom()
    elif args.arch_2 == 'ResNet10':
        model_2 = rn.ResNet10_Custom()
    else:
        print("Unrecognized {} for translocation feature prediction architecture".format(args.arch_2))

    model_1 = model_1.to(device)
    model_2 = model_2.to(device)

    # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
    if args.distributed:
        if args.cpu:
            model_1 = DDP(model_1)
            model_2 = DDP(model_2)
        else:
            model_1 = DDP(model_1, device_ids=[args.gpu], output_device=args.gpu)
            model_2 = DDP(model_2, device_ids=[args.gpu], output_device=args.gpu)

        if args.verbose:
            print('Since we are in a distributed setting the model is replicated here in local rank {}'
                  .format(args.local_rank))

    total_time = Utilities.AverageMeter()

    # bring counter from a checkpoint
    if args.counter:
        # Use a local scope to avoid dangling references
        def bring_counter():
            if os.path.isfile(args.counter):
                print("=> loading counter '{}'".format(args.counter))
                if args.cpu:
                    checkpoint = torch.load(args.counter, map_location='cpu')
                else:
                    checkpoint = torch.load(args.counter,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history_1 = checkpoint['loss_history']
                counter_error_history = checkpoint['Counter_error_history']
                best_error_1 = checkpoint['best_error']
                model_1.load_state_dict(checkpoint['state_dict'])
                total_time_1 = checkpoint['total_time']
                print("=> loaded counter '{}' (epoch {})".format(args.counter, checkpoint['epoch']))
                print("Model best precision saved was {}".format(best_error_1))
                return best_error_1, model_1, loss_history_1, counter_error_history, total_time_1
            else:
                print("=> no counter found at '{}'".format(args.counter))

        best_error_1, model_1, loss_history_1, counter_error_history, total_time_1 = bring_counter()
    else:
        raise Exception("error: No counter path provided")

    # bring predictor from a checkpoint
    if args.predictor:
        # Use a local scope to avoid dangling references
        def bring_predictor():
            if os.path.isfile(args.predictor):
                print("=> loading predictor '{}'".format(args.predictor))
                if args.cpu:
                    checkpoint = torch.load(args.predictor, map_location='cpu')
                else:
                    checkpoint = torch.load(args.predictor,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history_2 = checkpoint['loss_history']
                duration_error_history = checkpoint['duration_error_history']
                amplitude_error_history = checkpoint['amplitude_error_history']
                best_error_2 = checkpoint['best_error']
                model_2.load_state_dict(checkpoint['state_dict'])
                total_time_2 = checkpoint['total_time']
                print("=> loaded predictor '{}' (epoch {})".format(args.predictor, checkpoint['epoch']))
                print("Model best precision saved was {}".format(best_error_2))
                return best_error_2, model_2, loss_history_2, duration_error_history, amplitude_error_history, total_time_2
            else:
                print("=> no predictor found at '{}'".format(args.predictor))

        best_error_2, model_2, loss_history_2, duration_error_history, amplitude_error_history, total_time_2 = bring_predictor()
    else:
        raise Exception("error: No predictor path provided")

    # plots validation stats from a file
    if args.stats_from_file and args.local_rank == 0:
        # Use a local scope to avoid dangling references
        def bring_stats_from_file():
            if os.path.isfile(args.stats_from_file):
                print("=> loading stats from file '{}'".format(args.stats_from_file))
                if args.cpu:
                    stats = torch.load(args.stats_from_file, map_location='cpu')
                else:
                    stats = torch.load(args.stats_from_file,
                                       map_location=lambda storage, loc: storage.cuda(args.gpu))

                count_errors = stats['count_errors']
                duration_errors = stats['duration_errors']
                amplitude_errors = stats['amplitude_errors']
                Cnp = stats['Cnp']
                Duration = stats['Duration']
                Dnp = stats['Dnp']
                Arch = stats['Arch']
                print("=> loaded stats '{}'".format(args.stats_from_file))
                return count_errors, duration_errors, amplitude_errors, Cnp, Duration, Dnp, Arch
            else:
                print("=> no stats found at '{}'".format(args.stats_from_file))

        count_errors, duration_errors, amplitude_errors, Cnp, Duration, Dnp, Arch = bring_stats_from_file()
        plot_stats(Cnp, Duration, Dnp, count_errors, duration_errors, amplitude_errors)
        return

    # Data loading code
    valdir = os.path.join(args.data, 'test')

    if args.test:
        validation_f = h5py.File(valdir + '/test_toy.h5', 'r')
    else:
        validation_f = h5py.File(valdir + '/test.h5', 'r')

    # this is the dataset for validating
    sampling_rate = 10000               # number of samples per second of the signals in the dataset
    if args.test:
        number_of_concentrations = 2    # number of different concentrations in the dataset
        number_of_durations = 2         # number of different translocation durations per concentration
        number_of_diameters = 4         # number of different translocation diameters per duration
        window = 0.5                    # time window in seconds
        length = 10                     # duration of a complete signal for a given concentration and duration
    else:
        number_of_concentrations = 20
        number_of_durations = 5
        number_of_diameters = 15
        window = 0.5
        length = 10

    # Validating Artificial Data Loader
    VADL = Artificial_DataLoader(args.world_size, args.local_rank, device, validation_f,
                                 sampling_rate, number_of_concentrations, number_of_durations,
                                 number_of_diameters, window, length, args.batch_size)

    if args.verbose:
        print('From rank {} validation shard size is {}'.format(
            args.local_rank, VADL.get_number_of_avail_windows()))

    if args.run:
        arguments = {'model_1': model_1,
                     'model_2': model_2,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        if args.local_rank == 0:
            run_model(args, arguments)

        return

    if args.statistics:
        arguments = {'model_1': model_1,
                     'model_2': model_2,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        [count_errors, duration_errors, amplitude_errors, improper_measures] = compute_error_stats(args, arguments)

        if args.local_rank == 0:
            (Cnp, Duration, Dnp) = VADL.shape[:3]
            plot_stats(Cnp, Duration, Dnp, count_errors, duration_errors, amplitude_errors)
            print("This backbone produces {} improper measures.\n"
                  "Improper measures are produced when the ground truth establishes 0 number of pulses "
                  "but the network predicts one or more pulses.".format(improper_measures))

            if args.save_stats:
                Model_Util.save_stats({'count_errors': count_errors,
                                       'duration_errors': duration_errors,
                                       'amplitude_errors': amplitude_errors,
                                       'Cnp': VADL.shape[0],
                                       'Duration': VADL.shape[1],
                                       'Dnp': VADL.shape[2],
                                       'Arch': args.arch_2},
                                      args.save_stats)

        return

    if args.output_statistics:
        arguments = {'model_1': model_1,
                     'model_2': model_2,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        [counts, durations, amplitudes] = compute_output_stats(args, arguments)

        if args.local_rank == 0:
            (Cnp, Duration, Dnp) = VADL.shape[:3]
            plot_stats(Cnp, Duration, Dnp, counts, durations, amplitudes, Error=False)

            if args.save_stats:
                Model_Util.save_stats({'counts': counts,
                                       'durations': durations,
                                       'amplitudes': amplitudes,
                                       'Cnp': VADL.shape[0],
                                       'Duration': VADL.shape[1],
                                       'Dnp': VADL.shape[2],
                                       'Arch': args.arch_2},
                                      args.save_stats)

        return
def main():
    global best_error, args

    best_error = math.inf
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        if not args.cpu:
            torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='gloo', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    # Set the device
    device = torch.device('cpu' if args.cpu else 'cuda:' + str(args.gpu))

    # create model
    if args.test:
        args.arch = 'ResNet10'

    if args.local_rank == 0:
        print("=> creating model '{}'".format(args.arch))

    if args.arch == 'ResNet18':
        model = rn.ResNet18_Counter()
    elif args.arch == 'ResNet34':
        model = rn.ResNet34_Counter()
    elif args.arch == 'ResNet50':
        model = rn.ResNet50_Counter()
    elif args.arch == 'ResNet101':
        model = rn.ResNet101_Counter()
    elif args.arch == 'ResNet152':
        model = rn.ResNet152_Counter()
    elif args.arch == 'ResNet10':
        model = rn.ResNet10_Counter()
    else:
        print("Unrecognized {} architecture".format(args.arch))

    model = model.to(device)

    # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
    if args.distributed:
        if args.cpu:
            model = DDP(model)
        else:
            model = DDP(model, device_ids=[args.gpu], output_device=args.gpu)

        if args.verbose:
            print('Since we are in a distributed setting the model is replicated here in local rank {}'
                  .format(args.local_rank))

    # Set optimizer
    optimizer = Model_Util.get_optimizer(model, args)
    if args.local_rank == 0 and args.verbose:
        print('Optimizer used for this run is {}'.format(args.optimizer))

    # Set learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lrsp, args.lrm)

    total_time = Utilities.AverageMeter()
    loss_history = []
    counter_error_history = []

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                if args.cpu:
                    checkpoint = torch.load(args.resume, map_location='cpu')
                else:
                    checkpoint = torch.load(args.resume,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history = checkpoint['loss_history']
                counter_error_history = checkpoint['Counter_error_history']
                start_epoch = checkpoint['epoch']
                best_error = checkpoint['best_error']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                total_time = checkpoint['total_time']
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
                print("Model best precision saved was {}".format(best_error))
                return start_epoch, best_error, model, optimizer, lr_scheduler, loss_history, counter_error_history, total_time
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        args.start_epoch, best_error, model, optimizer, lr_scheduler, loss_history, counter_error_history, total_time = resume()

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    if args.test:
        training_f = h5py.File(traindir + '/train_toy.h5', 'r')
        validation_f = h5py.File(valdir + '/validation_toy.h5', 'r')
    else:
        training_f = h5py.File(traindir + '/train.h5', 'r')
        validation_f = h5py.File(valdir + '/validation.h5', 'r')

    # this is the dataset for training
    sampling_rate = 10000               # number of samples per second of the signals in the dataset
    if args.test:
        number_of_concentrations = 2    # number of different concentrations in the dataset
        number_of_durations = 2         # number of different translocation durations per concentration
        number_of_diameters = 4         # number of different translocation diameters per duration
        window = 0.5                    # time window in seconds
        length = 20                     # duration of a complete signal for a given concentration and duration
    else:
        number_of_concentrations = 20
        number_of_durations = 5
        number_of_diameters = 15
        window = 0.5
        length = 20

    # Training Artificial Data Loader
    TADL = Artificial_DataLoader(args.world_size, args.local_rank, device, training_f,
                                 sampling_rate, number_of_concentrations, number_of_durations,
                                 number_of_diameters, window, length, args.batch_size)

    # this is the dataset for validating
    if args.test:
        number_of_concentrations = 2
        number_of_durations = 2
        number_of_diameters = 4
        window = 0.5
        length = 10
    else:
        number_of_concentrations = 20
        number_of_durations = 5
        number_of_diameters = 15
        window = 0.5
        length = 10

    # Validating Artificial Data Loader
    VADL = Artificial_DataLoader(args.world_size, args.local_rank, device, validation_f,
                                 sampling_rate, number_of_concentrations, number_of_durations,
                                 number_of_diameters, window, length, args.batch_size)

    if args.verbose:
        print('From rank {} training shard size is {}'.format(
            args.local_rank, TADL.get_number_of_avail_windows()))
        print('From rank {} validation shard size is {}'.format(
            args.local_rank, VADL.get_number_of_avail_windows()))

    if args.run:
        arguments = {'model': model,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        if args.local_rank == 0:
            run_model(args, arguments)

        return

    if args.statistics:
        arguments = {'model': model,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        counter_errors = compute_error_stats(args, arguments)
        if args.local_rank == 0:
            plot_stats(VADL, counter_errors)

        return

    if args.evaluate:
        arguments = {'model': model,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        counter_error = validate(args, arguments)
        print('##Counter error {0}'.format(counter_error))

        return

    if args.plot_training_history and args.local_rank == 0:
        Model_Util.plot_counter_stats(loss_history, counter_error_history)
        hours = int(total_time.sum / 3600)
        minutes = int((total_time.sum % 3600) / 60)
        seconds = int((total_time.sum % 3600) % 60)
        print('The total training time was {} hours {} minutes and {} seconds'
              .format(hours, minutes, seconds))
        hours = int(total_time.avg / 3600)
        minutes = int((total_time.avg % 3600) / 60)
        seconds = int((total_time.avg % 3600) % 60)
        print('while the average time during one epoch of training was {} hours {} minutes and {} seconds'
              .format(hours, minutes, seconds))
        return

    for epoch in range(args.start_epoch, args.epochs):
        arguments = {'model': model,
                     'optimizer': optimizer,
                     'device': device,
                     'epoch': epoch,
                     'TADL': TADL,
                     'VADL': VADL,
                     'loss_history': loss_history,
                     'counter_error_history': counter_error_history}

        # train for one epoch
        epoch_time, avg_batch_time = train(args, arguments)
        total_time.update(epoch_time)

        # evaluate on validation set
        counter_error = validate(args, arguments)

        error = counter_error

        #if args.test:
            #break

        lr_scheduler.step()

        # remember the best model and save checkpoint
        if args.local_rank == 0:
            print('From validation we have error is {} while best_error is {}'
                  .format(error, best_error))
            is_best = error < best_error
            best_error = min(error, best_error)
            Model_Util.save_checkpoint({
                'arch': args.arch,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_error': best_error,
                'optimizer': optimizer.state_dict(),
                'loss_history': loss_history,
                'Counter_error_history': counter_error_history,
                'lr_scheduler': lr_scheduler.state_dict(),
                'total_time': total_time
            }, is_best)

            print('##Counter error {0}\n'
                  '##Perf {1}'.format(counter_error, args.total_batch_size / avg_batch_time))
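# The main() functions above read WORLD_SIZE from the environment and expect
# an args.local_rank, which matches the torch.distributed.launch contract.
# A hypothetical launch line (the script name and any flags other than
# --nproc_per_node are assumptions, not taken from this file):
#
#     python -m torch.distributed.launch --nproc_per_node=4 train.py <data dir>
#
# Each worker can then decide whether it is running distributed like this:

def _is_distributed_sketch():
    import os
    return int(os.environ.get('WORLD_SIZE', '1')) > 1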
def validate(args, arguments):
    batch_time = Utilities.AverageMeter()
    average_duration_error = Utilities.AverageMeter()
    average_amplitude_error = Utilities.AverageMeter()

    # switch to evaluate mode
    arguments['model'].eval()

    end = time.time()

    val_loader_len = int(math.ceil(arguments['VADL'].shard_size / args.batch_size))
    i = 0
    arguments['VADL'].reset_avail_winds(arguments['epoch'])
    while i * arguments['VADL'].batch_size < arguments['VADL'].shard_size:
        # bring a new batch
        times, noisy_signals, clean_signals, _, labels = arguments['VADL'].get_batch()

        mean = torch.mean(noisy_signals, 1, True)
        noisy_signals = noisy_signals - mean

        with torch.no_grad():
            noisy_signals = noisy_signals.unsqueeze(1)
            # the ground-truth pulse count is fed to the predictor as an external input
            external = torch.reshape(labels[:, 0], [arguments['VADL'].batch_size, 1])
            outputs = arguments['model'](noisy_signals, external)
            noisy_signals = noisy_signals.squeeze(1)

            # rescale the raw outputs to the label units (1e-3 for duration,
            # 1e-10 for amplitude) and compute the percentage error per feature
            errors = abs((labels[:, 1:].to('cpu')
                          - outputs.data.to('cpu') * torch.Tensor([10**(-3), 10**(-10)]).repeat(arguments['VADL'].batch_size, 1))
                         / labels[:, 1:].to('cpu')) * 100
            errors = torch.mean(errors, dim=0)

            duration_error = errors[0]
            amplitude_error = errors[1]

        if args.distributed:
            reduced_duration_error = Utilities.reduce_tensor(duration_error.data, args.world_size)
            reduced_amplitude_error = Utilities.reduce_tensor(amplitude_error.data, args.world_size)
        else:
            reduced_duration_error = duration_error.data
            reduced_amplitude_error = amplitude_error.data

        average_duration_error.update(Utilities.to_python_float(reduced_duration_error), args.batch_size)
        average_amplitude_error.update(Utilities.to_python_float(reduced_amplitude_error), args.batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #if args.test:
            #if i > 10:
                #break

        if args.local_rank == 0 and i % args.print_freq == 0 and i != 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Duration Error {dur_error.val:.4f} ({dur_error.avg:.4f})\t'
                  'Amplitude Error {amp_error.val:.4f} ({amp_error.avg:.4f})'.format(
                      i, val_loader_len,
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time,
                      dur_error=average_duration_error,
                      amp_error=average_amplitude_error))

        i += 1

    if not args.evaluate:
        arguments['duration_error_history'].append(average_duration_error.avg)
        arguments['amplitude_error_history'].append(average_amplitude_error.avg)

    return [average_duration_error.avg, average_amplitude_error.avg]
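# The error above rescales the raw network outputs by [1e-3, 1e-10] so they
# share units with the labels (duration and amplitude) before taking a
# percentage error. A worked example with illustrative magnitudes:

def _feature_error_example():
    import torch

    labels = torch.tensor([[3.0, 2.0e-3, 5.0e-10]])  # [count, duration, amplitude]
    outputs = torch.tensor([[2.1, 4.8]])             # raw predictions in scaled units
    scale = torch.tensor([10**(-3), 10**(-10)])

    errors = abs((labels[:, 1:] - outputs * scale) / labels[:, 1:]) * 100
    return torch.mean(errors, dim=0)  # tensor([5., 4.]) -> 5% duration, 4% amplitude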
def main():
    global best_error, args

    best_error = math.inf
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    if not len(args.counter):
        raise Exception("error: No path to counter model provided")

    if not len(args.predictor):
        raise Exception("error: No path to predictor model provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        if not args.cpu:
            torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='gloo', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    # Set the device
    device = torch.device('cpu' if args.cpu else 'cuda:' + str(args.gpu))

    # create model_1 (translocation counter)
    if args.test:
        args.arch_1 = 'ResNet10'

    if args.local_rank == 0:
        print("=> creating model_1 '{}'".format(args.arch_1))

    if args.arch_1 == 'ResNet18':
        model_1 = rn.ResNet18_Counter()
    elif args.arch_1 == 'ResNet34':
        model_1 = rn.ResNet34_Counter()
    elif args.arch_1 == 'ResNet50':
        model_1 = rn.ResNet50_Counter()
    elif args.arch_1 == 'ResNet101':
        model_1 = rn.ResNet101_Counter()
    elif args.arch_1 == 'ResNet152':
        model_1 = rn.ResNet152_Counter()
    elif args.arch_1 == 'ResNet10':
        model_1 = rn.ResNet10_Counter()
    else:
        print("Unrecognized {} for translocations counter architecture".format(args.arch_1))

    # create model_2 (translocation feature predictor)
    if args.test:
        args.arch_2 = 'ResNet10'

    if args.local_rank == 0:
        print("=> creating model_2 '{}'".format(args.arch_2))

    if args.arch_2 == 'ResNet18':
        model_2 = rn.ResNet18_Custom()
    elif args.arch_2 == 'ResNet34':
        model_2 = rn.ResNet34_Custom()
    elif args.arch_2 == 'ResNet50':
        model_2 = rn.ResNet50_Custom()
    elif args.arch_2 == 'ResNet101':
        model_2 = rn.ResNet101_Custom()
    elif args.arch_2 == 'ResNet152':
        model_2 = rn.ResNet152_Custom()
    elif args.arch_2 == 'ResNet10':
        model_2 = rn.ResNet10_Custom()
    else:
        print("Unrecognized {} for translocation feature prediction architecture".format(args.arch_2))

    model_1 = model_1.to(device)
    model_2 = model_2.to(device)

    # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
    if args.distributed:
        if args.cpu:
            model_1 = DDP(model_1)
            model_2 = DDP(model_2)
        else:
            model_1 = DDP(model_1, device_ids=[args.gpu], output_device=args.gpu)
            model_2 = DDP(model_2, device_ids=[args.gpu], output_device=args.gpu)

        if args.verbose:
            print('Since we are in a distributed setting the model is replicated here in local rank {}'
                  .format(args.local_rank))

    total_time = Utilities.AverageMeter()

    # bring counter from a checkpoint
    if args.counter:
        # Use a local scope to avoid dangling references
        def bring_counter():
            if os.path.isfile(args.counter):
                print("=> loading counter '{}'".format(args.counter))
                if args.cpu:
                    checkpoint = torch.load(args.counter, map_location='cpu')
                else:
                    checkpoint = torch.load(args.counter,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history_1 = checkpoint['loss_history']
                counter_error_history = checkpoint['Counter_error_history']
                best_error_1 = checkpoint['best_error']
                model_1.load_state_dict(checkpoint['state_dict'])
                total_time_1 = checkpoint['total_time']
                print("=> loaded counter '{}' (epoch {})".format(args.counter, checkpoint['epoch']))
                print("Model best precision saved was {}".format(best_error_1))
                return best_error_1, model_1, loss_history_1, counter_error_history, total_time_1
            else:
                print("=> no counter found at '{}'".format(args.counter))

        best_error_1, model_1, loss_history_1, counter_error_history, total_time_1 = bring_counter()
    else:
        raise Exception("error: No counter path provided")

    # bring predictor from a checkpoint
    if args.predictor:
        # Use a local scope to avoid dangling references
        def bring_predictor():
            if os.path.isfile(args.predictor):
                print("=> loading predictor '{}'".format(args.predictor))
                if args.cpu:
                    checkpoint = torch.load(args.predictor, map_location='cpu')
                else:
                    checkpoint = torch.load(args.predictor,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history_2 = checkpoint['loss_history']
                duration_error_history = checkpoint['duration_error_history']
                amplitude_error_history = checkpoint['amplitude_error_history']
                best_error_2 = checkpoint['best_error']
                model_2.load_state_dict(checkpoint['state_dict'])
                total_time_2 = checkpoint['total_time']
                print("=> loaded predictor '{}' (epoch {})".format(args.predictor, checkpoint['epoch']))
                print("Model best precision saved was {}".format(best_error_2))
                return best_error_2, model_2, loss_history_2, duration_error_history, amplitude_error_history, total_time_2
            else:
                print("=> no predictor found at '{}'".format(args.predictor))

        best_error_2, model_2, loss_history_2, duration_error_history, amplitude_error_history, total_time_2 = bring_predictor()
    else:
        raise Exception("error: No predictor path provided")

    # plots validation stats from a file
    if args.stats_from_file and args.local_rank == 0:
        # Use a local scope to avoid dangling references
        def bring_stats_from_file():
            if os.path.isfile(args.stats_from_file):
                print("=> loading stats from file '{}'".format(args.stats_from_file))
                if args.cpu:
                    stats = torch.load(args.stats_from_file, map_location='cpu')
                else:
                    stats = torch.load(args.stats_from_file,
                                       map_location=lambda storage, loc: storage.cuda(args.gpu))

                count_translocations = stats['count_translocations']
                duration_translocations = stats['duration_translocations']
                amplitude_translocations = stats['amplitude_translocations']
                Trace = stats['Trace']
                Arch = stats['Arch']
                print("=> loaded stats '{}'".format(args.stats_from_file))
                return count_translocations, duration_translocations, amplitude_translocations, Trace, Arch
            else:
                print("=> no stats found at '{}'".format(args.stats_from_file))

        count_translocations, duration_translocations, amplitude_translocations, Trace, Arch = bring_stats_from_file()
        plot_stats(Trace, count_translocations, duration_translocations, amplitude_translocations)
        return

    # Data loading code
    testdir = args.data

    if args.test:
        test_f = h5py.File(testdir + '/test_toy.h5', 'r')
    else:
        test_f = h5py.File(testdir + '/test.h5', 'r')

    # this is the dataset of real traces used for testing
    if args.test:
        num_of_traces = 2               # number of different traces in the dataset
        window = 0.5                    # time window in seconds
        length = args.trace_length      # duration of a complete recorded trace
    else:
        num_of_traces = 6
        window = 0.5
        length = args.trace_length

    # Unlabeled Real Data Loader
    TRDL = Unlabeled_Real_DataLoader(device, test_f, num_of_traces, window, length)

    if args.run:
        arguments = {'model_1': model_1,
                     'model_2': model_2,
                     'device': device,
                     'TRDL': TRDL}

        if args.local_rank == 0:
            run_model(args, arguments)

        return

    if args.statistics:
        arguments = {'model_1': model_1,
                     'model_2': model_2,
                     'device': device,
                     'TRDL': TRDL}

        [count_translocations, duration_translocations, amplitude_translocations] = compute_value_stats(args, arguments)

        if args.local_rank == 0:
            Trace = TRDL.shape[0]
            plot_stats(Trace, count_translocations, duration_translocations, amplitude_translocations)

            if args.save_stats:
                Model_Util.save_stats({'count_translocations': count_translocations,
                                       'duration_translocations': duration_translocations,
                                       'amplitude_translocations': amplitude_translocations,
                                       'Trace': TRDL.shape[0],
                                       'Arch': args.arch_2},
                                      args.save_stats)

        return
def validate(args, arguments):
    average_precision = Utilities.AverageMeter()

    # switch to evaluate mode
    arguments['detr'].eval()

    end = time.time()

    val_loader_len = int(math.ceil(arguments['VADL'].shard_size / args.batch_size))
    i = 0
    arguments['VADL'].reset_avail_winds(arguments['epoch'])

    pred_segments = []
    true_segments = []
    while i * arguments['VADL'].batch_size < arguments['VADL'].shard_size:
        # get the noisy inputs and the labels from the validation loader
        _, inputs, _, targets, labels = arguments['VADL'].get_batch()

        mean = torch.mean(inputs, 1, True)
        inputs = inputs - mean

        with torch.no_grad():
            # forward
            inputs = inputs.unsqueeze(1)
            outputs = arguments['detr'](inputs)

        for j in range(arguments['VADL'].batch_size):
            train_idx = int(j + i * arguments['VADL'].batch_size)

            probabilities = F.softmax(outputs['pred_logits'][j], dim=1)
            aux_pred_segments = outputs['pred_segments'][j]

            for probability, pred_segment in zip(probabilities.to('cpu'), aux_pred_segments.to('cpu')):
                #if probability[-1] < 0.9:
                if torch.argmax(probability) != args.num_classes:
                    # [window index, predicted class, confidence, position, width]
                    segment = [train_idx,
                               np.argmax(probability[:-1]).item(),
                               1.0 - probability[-1].item(),
                               pred_segment[0].item(),
                               pred_segment[1].item()]
                    pred_segments.append(segment)

            num_pulses = labels[j, 0]
            starts = targets[j, 0]
            widths = targets[j, 1]
            categories = targets[j, 3]

            for k in range(int(num_pulses.item())):
                segment = [train_idx, categories[k].item(), 1.0, starts[k].item(), widths[k].item()]
                true_segments.append(segment)

        i += 1

    # average the mAP over IoU thresholds from 0.5 to 0.9 in steps of 0.05
    for threshold in np.arange(0.5, 0.95, 0.05):
        detection_precision = mean_average_precision(device=arguments['device'],
                                                     pred_segments=pred_segments,
                                                     true_segments=true_segments,
                                                     iou_threshold=threshold,
                                                     seg_format="mix",
                                                     num_classes=1)

        if args.distributed:
            reduced_detection_precision = Utilities.reduce_tensor(detection_precision.data, args.world_size)
        else:
            reduced_detection_precision = detection_precision.data

        average_precision.update(Utilities.to_python_float(reduced_detection_precision))

    if not args.evaluate:
        arguments['precision_history'].append(average_precision.avg)

    return average_precision.avg
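# mean_average_precision above sweeps IoU thresholds from 0.5 to 0.9 over 1-D
# segments stored as [window index, class, confidence, position, width].
# A sketch of the interval IoU such an evaluation needs, assuming
# (start, width) coordinates -- the actual "mix" format handling lives in
# the mean_average_precision implementation:

def _segment_iou_sketch(start_a, width_a, start_b, width_b):
    end_a = start_a + width_a
    end_b = start_b + width_b
    intersection = max(0.0, min(end_a, end_b) - max(start_a, start_b))
    union = width_a + width_b - intersection
    return intersection / union if union > 0 else 0.0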
def train(args, arguments):
    batch_time = Utilities.AverageMeter()
    losses = Utilities.AverageMeter()

    # switch to train mode
    arguments['detr'].train()

    end = time.time()

    train_loader_len = int(math.ceil(arguments['TADL'].shard_size / args.batch_size))
    i = 0
    arguments['TADL'].reset_avail_winds(arguments['epoch'])

    ########################################################
    #_, inputs, _, targets, _ = arguments['TADL'].get_batch()
    #targets = transform_targets(targets)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(arguments['optimizer'], args.lrsp,
    #                                               args.lrm)
    ########################################################

    ########################################################
    #while True:
    ########################################################
    while i * arguments['TADL'].batch_size < arguments['TADL'].shard_size:
        # get the noisy inputs and the labels
        _, inputs, _, targets, _ = arguments['TADL'].get_batch()

        mean = torch.mean(inputs, 1, True)
        inputs = inputs - mean

        # zero the parameter gradients
        arguments['optimizer'].zero_grad()

        # forward + backward + optimize
        inputs = inputs.unsqueeze(1)
        outputs = arguments['detr'](inputs)

        ########################################################
        #inputs = inputs.squeeze(1)
        ########################################################

        # Compute the loss
        targets = transform_targets(targets)
        loss_dict = arguments['criterion'].forward(outputs=outputs, targets=targets)
        weight_dict = arguments['criterion'].weight_dict
        loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # compute gradient and do optimizer step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(arguments['detr'].parameters(), 0.1)
        arguments['optimizer'].step()

        #if args.test:
            #if i > 10:
                #break

        if i % args.print_freq == 0:
        #if i % args.print_freq == 0 and i != 0:
            # Every print_freq iterations, check the loss and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Average loss across processes for logging
            if args.distributed:
                reduced_loss = Utilities.reduce_tensor(loss.data, args.world_size)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(Utilities.to_python_float(reduced_loss), args.batch_size)

            if not args.cpu:
                torch.cuda.synchronize()

            batch_time.update((time.time() - end) / args.print_freq, args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})'.format(
                          arguments['epoch'], i, train_loader_len,
                          args.world_size * args.batch_size / batch_time.val,
                          args.world_size * args.batch_size / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses))

        i += 1

    ########################################################
    #lr_scheduler.step()
    ########################################################

    arguments['loss_history'].append(losses.avg)

    return batch_time.sum, batch_time.avg
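# The total DETR loss above is a weighted sum over whichever entries of
# loss_dict also appear in weight_dict; terms such as the cardinality error
# are logged but never optimized. Illustrative numbers only (the real
# weights come from args in main(), so the values here are assumptions):

def _weighted_loss_example():
    import torch

    loss_dict = {'loss_ce': torch.tensor(0.7),
                 'loss_bsegment': torch.tensor(0.2),
                 'loss_giou': torch.tensor(0.4),
                 'cardinality_error': torch.tensor(1.0)}  # excluded from the sum
    weight_dict = {'loss_ce': 1.0, 'loss_bsegment': 5.0, 'loss_giou': 2.0}

    loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
    return loss  # 0.7*1.0 + 0.2*5.0 + 0.4*2.0 = 2.5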
def main():
    global best_precision, args

    best_precision = 0
    args = parse()

    if not len(args.data):
        raise Exception("error: No data set provided")

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        if not args.cpu:
            torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='gloo', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    # Set the device
    device = torch.device('cpu' if args.cpu else 'cuda:' + str(args.gpu))

    #######################################################################
    #   Start DETR construction
    #######################################################################

    # create DETR backbone

    # create backbone pulse counter
    if args.test:
        args.pulse_counter_arch = 'ResNet10'

    if args.local_rank == 0 and args.verbose:
        print("=> creating backbone pulse counter '{}'".format(args.pulse_counter_arch))

    if args.pulse_counter_arch == 'ResNet18':
        backbone_pulse_counter = rn.ResNet18_Counter()
    elif args.pulse_counter_arch == 'ResNet34':
        backbone_pulse_counter = rn.ResNet34_Counter()
    elif args.pulse_counter_arch == 'ResNet50':
        backbone_pulse_counter = rn.ResNet50_Counter()
    elif args.pulse_counter_arch == 'ResNet101':
        backbone_pulse_counter = rn.ResNet101_Counter()
    elif args.pulse_counter_arch == 'ResNet152':
        backbone_pulse_counter = rn.ResNet152_Counter()
    elif args.pulse_counter_arch == 'ResNet10':
        backbone_pulse_counter = rn.ResNet10_Counter()
    else:
        print("Unrecognized {} architecture for the backbone pulse counter".format(args.pulse_counter_arch))

    backbone_pulse_counter = backbone_pulse_counter.to(device)

    # create backbone feature predictor
    if args.test:
        args.feature_predictor_arch = 'ResNet10'

    if args.local_rank == 0 and args.verbose:
        print("=> creating backbone feature predictor '{}'".format(args.feature_predictor_arch))

    if args.feature_predictor_arch == 'ResNet18':
        backbone_feature_predictor = rn.ResNet18_Custom()
    elif args.feature_predictor_arch == 'ResNet34':
        backbone_feature_predictor = rn.ResNet34_Custom()
    elif args.feature_predictor_arch == 'ResNet50':
        backbone_feature_predictor = rn.ResNet50_Custom()
    elif args.feature_predictor_arch == 'ResNet101':
        backbone_feature_predictor = rn.ResNet101_Custom()
    elif args.feature_predictor_arch == 'ResNet152':
        backbone_feature_predictor = rn.ResNet152_Custom()
    elif args.feature_predictor_arch == 'ResNet10':
        backbone_feature_predictor = rn.ResNet10_Custom()
    else:
        print("Unrecognized {} architecture for the backbone feature predictor".format(args.feature_predictor_arch))

    backbone_feature_predictor = backbone_feature_predictor.to(device)

    # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
    if args.distributed:
        if args.cpu:
            backbone_pulse_counter = DDP(backbone_pulse_counter)
            backbone_feature_predictor = DDP(backbone_feature_predictor)
        else:
            backbone_pulse_counter = DDP(backbone_pulse_counter, device_ids=[args.gpu], output_device=args.gpu)
            backbone_feature_predictor = DDP(backbone_feature_predictor, device_ids=[args.gpu], output_device=args.gpu)

        if args.verbose:
            print('Since we are in a distributed setting the backbone components are replicated here in local rank {}'
                  .format(args.local_rank))

    # bring counter from a checkpoint
    if args.counter:
        # Use a local scope to avoid dangling references
        def bring_counter():
            if os.path.isfile(args.counter):
                print("=> loading backbone pulse counter '{}'".format(args.counter))
                if args.cpu:
                    checkpoint = torch.load(args.counter, map_location='cpu')
                else:
                    checkpoint = torch.load(args.counter,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history_1 = checkpoint['loss_history']
                counter_error_history = checkpoint['Counter_error_history']
                best_error_1 = checkpoint['best_error']
                backbone_pulse_counter.load_state_dict(checkpoint['state_dict'])
                total_time_1 = checkpoint['total_time']
                print("=> loaded counter '{}' (epoch {})".format(args.counter, checkpoint['epoch']))
                print("Counter best precision saved was {}".format(best_error_1))
                return best_error_1, backbone_pulse_counter, loss_history_1, counter_error_history, total_time_1
            else:
                print("=> no counter found at '{}'".format(args.counter))

        best_error_1, backbone_pulse_counter, loss_history_1, counter_error_history, total_time_1 = bring_counter()
    else:
        raise Exception("error: No counter path provided")

    # bring predictor from a checkpoint
    if args.predictor:
        # Use a local scope to avoid dangling references
        def bring_predictor():
            if os.path.isfile(args.predictor):
                print("=> loading backbone feature predictor '{}'".format(args.predictor))
                if args.cpu:
                    checkpoint = torch.load(args.predictor, map_location='cpu')
                else:
                    checkpoint = torch.load(args.predictor,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history_2 = checkpoint['loss_history']
                duration_error_history = checkpoint['duration_error_history']
                amplitude_error_history = checkpoint['amplitude_error_history']
                best_error_2 = checkpoint['best_error']
                backbone_feature_predictor.load_state_dict(checkpoint['state_dict'])
                total_time_2 = checkpoint['total_time']
                print("=> loaded predictor '{}' (epoch {})".format(args.predictor, checkpoint['epoch']))
                print("Predictor best precision saved was {}".format(best_error_2))
                return best_error_2, backbone_feature_predictor, loss_history_2, duration_error_history, amplitude_error_history, total_time_2
            else:
                print("=> no predictor found at '{}'".format(args.predictor))

        best_error_2, backbone_feature_predictor, loss_history_2, duration_error_history, amplitude_error_history, total_time_2 = bring_predictor()
    else:
        raise Exception("error: No predictor path provided")

    # create backbone
    if args.local_rank == 0 and args.verbose:
        print("=> creating backbone")

    if args.feature_predictor_arch == 'ResNet18':
        backbone = build_backbone(pulse_counter=backbone_pulse_counter,
                                  feature_predictor=backbone_feature_predictor,
                                  num_channels=512)
    elif args.feature_predictor_arch == 'ResNet34':
        backbone = build_backbone(pulse_counter=backbone_pulse_counter,
                                  feature_predictor=backbone_feature_predictor,
                                  num_channels=512)
    elif args.feature_predictor_arch == 'ResNet50':
        backbone = build_backbone(pulse_counter=backbone_pulse_counter,
                                  feature_predictor=backbone_feature_predictor,
                                  num_channels=2048)
    elif args.feature_predictor_arch == 'ResNet101':
        backbone = build_backbone(pulse_counter=backbone_pulse_counter,
                                  feature_predictor=backbone_feature_predictor,
                                  num_channels=2048)
    elif args.feature_predictor_arch == 'ResNet152':
        backbone = build_backbone(pulse_counter=backbone_pulse_counter,
                                  feature_predictor=backbone_feature_predictor,
                                  num_channels=2048)
    elif args.feature_predictor_arch == 'ResNet10':
        backbone = build_backbone(pulse_counter=backbone_pulse_counter,
                                  feature_predictor=backbone_feature_predictor,
                                  num_channels=512)
    else:
        print("Unrecognized {} architecture for the backbone feature predictor".format(args.feature_predictor_arch))

    backbone = backbone.to(device)

    # create DETR transformer
    if args.local_rank == 0 and args.verbose:
        print("=> creating transformer")

    if args.test:
        args.transformer_hidden_dim = 64
        args.transformer_num_heads = 2
        args.transformer_dim_feedforward = 256
        args.transformer_num_enc_layers = 2
        args.transformer_num_dec_layers = 2
        args.transformer_pre_norm = True

    transformer = build_transformer(hidden_dim=args.transformer_hidden_dim,
                                    dropout=args.transformer_dropout,
                                    nheads=args.transformer_num_heads,
                                    dim_feedforward=args.transformer_dim_feedforward,
                                    enc_layers=args.transformer_num_enc_layers,
                                    dec_layers=args.transformer_num_dec_layers,
                                    pre_norm=args.transformer_pre_norm)

    # create DETR itself
    if args.local_rank == 0 and args.verbose:
        print("=> creating DETR")

    detr = DT.DETR(backbone=backbone,
                   transformer=transformer,
                   num_classes=args.num_classes,
                   num_queries=args.num_queries)

    detr = detr.to(device)

    # For distributed training, wrap the model with torch.nn.parallel.DistributedDataParallel.
    if args.distributed:
        if args.cpu:
            detr = DDP(detr)
        else:
            detr = DDP(detr, device_ids=[args.gpu], output_device=args.gpu)

        if args.verbose:
            print('Since we are in a distributed setting DETR model is replicated here in local rank {}'
                  .format(args.local_rank))

    # Set matcher
    if args.local_rank == 0 and args.verbose:
        print("=> set Hungarian Matcher")

    matcher = mtchr.HungarianMatcher(cost_class=args.cost_class,
                                     cost_bsegment=args.cost_bsegment,
                                     cost_giou=args.cost_giou)

    # Set criterion
    if args.local_rank == 0 and args.verbose:
        print("=> set criterion for the loss")

    weight_dict = {'loss_ce': args.loss_ce,
                   'loss_bsegment': args.loss_bsegment,
                   'loss_giou': args.loss_giou}

    losses = ['labels', 'segments', 'cardinality']

    criterion = DT.SetCriterion(num_classes=args.num_classes,
                                matcher=matcher,
                                weight_dict=weight_dict,
                                eos_coef=args.eos_coef,
                                losses=losses)

    criterion = criterion.to(device)

    # Set optimizer
    optimizer = Model_Util.get_DETR_optimizer(detr, args)
    if args.local_rank == 0 and args.verbose:
        print('Optimizer used for this run is {}'.format(args.optimizer))

    # Set learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lrsp, args.lrm)

    total_time = Utilities.AverageMeter()
    loss_history = []
    precision_history = []

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                if args.cpu:
                    checkpoint = torch.load(args.resume, map_location='cpu')
                else:
                    checkpoint = torch.load(args.resume,
                                            map_location=lambda storage, loc: storage.cuda(args.gpu))

                loss_history = checkpoint['loss_history']
                precision_history = checkpoint['precision_history']
                start_epoch = checkpoint['epoch']
                best_precision = checkpoint['best_precision']
                detr.load_state_dict(checkpoint['state_dict'])
                criterion.load_state_dict(checkpoint['criterion'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                total_time = checkpoint['total_time']
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
                return start_epoch, detr, criterion, optimizer, lr_scheduler, loss_history, precision_history, total_time, best_precision
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        args.start_epoch, detr, criterion, optimizer, lr_scheduler, loss_history, precision_history, total_time, best_precision = resume()

    # Data loading code
    if len(args.data) == 1:
        traindir = os.path.join(args.data[0], 'train')
        valdir = os.path.join(args.data[0], 'val')
    else:
        traindir = args.data[0]
        valdir = args.data[1]

    if args.test:
        training_f = h5py.File(traindir + '/train_toy.h5', 'r')
        validation_f = h5py.File(valdir + '/validation_toy.h5', 'r')
    else:
        training_f = h5py.File(traindir + '/train.h5', 'r')
        validation_f = h5py.File(valdir + '/validation.h5', 'r')

    # this is the dataset for training
    sampling_rate = 10000               # number of samples per second of the signals in the dataset
    if args.test:
        number_of_concentrations = 2    # number of different concentrations in the dataset
        number_of_durations = 2         # number of different translocation durations per concentration
        number_of_diameters = 4         # number of different translocation diameters per duration
        window = 0.5                    # time window in seconds
        length = 20                     # duration of a complete signal for a given concentration and duration
    else:
        number_of_concentrations = 20
        number_of_durations = 5
        number_of_diameters = 15
        window = 0.5
        length = 20

    # Training Artificial Data Loader
    TADL = Artificial_DataLoader(args.world_size, args.local_rank, device, training_f,
                                 sampling_rate, number_of_concentrations, number_of_durations,
                                 number_of_diameters, window, length, args.batch_size)

    # this is the dataset for validating
    if args.test:
        number_of_concentrations = 2
        number_of_durations = 2
        number_of_diameters = 4
        window = 0.5
        length = 10
    else:
        number_of_concentrations = 20
        number_of_durations = 5
        number_of_diameters = 15
        window = 0.5
        length = 10

    # Validating Artificial Data Loader
    VADL = Artificial_DataLoader(args.world_size, args.local_rank, device, validation_f,
                                 sampling_rate, number_of_concentrations, number_of_durations,
                                 number_of_diameters, window, length, args.batch_size)

    if args.verbose:
        print('From rank {} training shard size is {}'.format(
            args.local_rank, TADL.get_number_of_avail_windows()))
        print('From rank {} validation shard size is {}'.format(
            args.local_rank, VADL.get_number_of_avail_windows()))

    if args.run:
        arguments = {'model': detr,
                     'device': device,
                     'epoch': 0,
                     'VADL': VADL}

        if args.local_rank == 0:
            run_model(args, arguments)

        return

    #if args.statistics:
        #arguments = {'model': model,
                     #'device': device,
                     #'epoch': 0,
                     #'VADL': VADL}

        #[duration_errors, amplitude_errors] = compute_error_stats(args, arguments)
        #if args.local_rank == 0:
            #plot_stats(VADL, duration_errors, amplitude_errors)

        #return

    #if args.evaluate:
        #arguments = {'model': model,
                     #'device': device,
                     #'epoch': 0,
                     #'VADL': VADL}

        #[duration_error, amplitude_error] = validate(args, arguments)
        #print('##Duration error {0}\n'
              #'##Amplitude error {1}'.format(duration_error, amplitude_error))

        #return

    if args.plot_training_history and args.local_rank == 0:
        Model_Util.plot_detector_stats(loss_history, precision_history)
        hours = int(total_time.sum / 3600)
        minutes = int((total_time.sum % 3600) / 60)
        seconds = int((total_time.sum % 3600) % 60)
        print('The total training time was {} hours {} minutes and {} seconds'
              .format(hours, minutes, seconds))
        hours = int(total_time.avg / 3600)
        minutes = int((total_time.avg % 3600) / 60)
        seconds = int((total_time.avg % 3600) % 60)
        print('while the average time during one epoch of training was {} hours {} minutes and {} seconds'
              .format(hours, minutes, seconds))
        return

    for epoch in range(args.start_epoch, args.epochs):
        arguments = {'detr': detr,
                     'criterion': criterion,
                     'optimizer': optimizer,
                     'device': device,
                     'epoch': epoch,
                     'TADL': TADL,
                     'VADL': VADL,
                     'loss_history': loss_history,
                     'precision_history': precision_history}

        # train for one epoch
        epoch_time, avg_batch_time = train(args, arguments)
        total_time.update(epoch_time)

        # validate every val_freq epochs
        if epoch % args.val_freq == 0 and epoch != 0:
            # evaluate on validation set
            print("\nValidating ...\nComputing mean average precision (mAP) for epoch {}".format(epoch))
            precision = validate(args, arguments)
        else:
            precision = best_precision

        #if args.test:
            #break

        lr_scheduler.step()

        # remember the best detr and save checkpoint
        if args.local_rank == 0:
            if epoch % args.val_freq == 0:
                print('From validation we have precision is {} while best_precision is {}'
                      .format(precision, best_precision))

            is_best = precision > best_precision
            best_precision = max(precision, best_precision)
            Model_Util.save_checkpoint({
                'arch': 'DETR_' + args.feature_predictor_arch,
                'epoch': epoch + 1,
                'best_precision': best_precision,
                'state_dict': detr.state_dict(),
                'criterion': criterion.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_history': loss_history,
                'precision_history': precision_history,
                'lr_scheduler': lr_scheduler.state_dict(),
                'total_time': total_time
            }, is_best)

            print('##Detector precision {0}\n'
                  '##Perf {1}'.format(precision, args.total_batch_size / avg_batch_time))
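# The checkpoint dictionary saved above can be inspected offline. The keys
# below are taken from the save_checkpoint call; the default file name is an
# assumption about where Model_Util.save_checkpoint writes:

def _inspect_checkpoint_sketch(path='checkpoint.pth.tar'):
    import torch

    ckpt = torch.load(path, map_location='cpu')
    print(ckpt['arch'], 'epoch', ckpt['epoch'], 'best precision', ckpt['best_precision'])
    print('last loss:', ckpt['loss_history'][-1])
    print('last mAP:', ckpt['precision_history'][-1])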