def save_summary(epoch: int, global_step: int,
                 accuracies: List[utils.AverageMeter],
                 duration: timedelta, tracking_file: str,
                 mode: str, top=(1, )):
    """
    Record a one-row summary for an epoch.

    Parameters
    ----------
    epoch : int
        Epoch the summary describes
    global_step : int
        Step counter stored alongside the epoch
    accuracies : List[utils.AverageMeter]
        Meters whose `avg` is stored, paired positionally with `top`
    duration : timedelta
        Duration stored in the summary
    tracking_file : str
        Destination handed to `utils.save_result`
    mode : str
        Value stored under the `mode` key
    top : Tuple[int, ...], default (1, )
        The `k` values used to name each `top{k}_accuracy` column
    """
    # Field order is significant: it is the insertion order of the
    # OrderedDict that gets passed to `utils.save_result`.
    summary: Dict[str, Any] = OrderedDict([
        ('timestamp', datetime.now()),
        ('mode', mode),
        ('epoch', epoch),
        ('global_step', global_step),
        ('duration', duration),
    ])
    # `zip` stops at the shorter of `top` and `accuracies`.
    summary.update(
        (f'top{k}_accuracy', meter.avg)
        for k, meter in zip(top, accuracies)
    )
    utils.save_result(summary, tracking_file)
test_loader=test_loader, fp16=proxy_fp16, run_dir=proxy_run_dir, checkpoint=checkpoint, batch_callback=batch_callback) # For analysis, record details about training the proxy. proxy_stats: Dict[str, Any] = OrderedDict() proxy_stats['nexamples'] = len(train_indices) proxy_stats['train_accuracy'] = proxy_accuracies.train proxy_stats['dev_accuracy'] = proxy_accuracies.dev proxy_stats['test_accuracy'] = proxy_accuracies.test proxy_stats['train_time'] = proxy_times.train proxy_stats['dev_time'] = proxy_times.dev proxy_stats['test_time'] = proxy_times.test utils.save_result(proxy_stats, os.path.join(run_dir, "proxy.csv")) current = np.array([], dtype=np.int64) # Create initial random subset for greedy k-center method. # Everything else can start with an empty selected subset. if selection_method == 'kcenters': assert subset > 1_000 # TODO: Maybe this shouldn't be hardcoded current = np.random.permutation(train_indices)[:1_000] nevents = None if selection_method == 'forgetting_events': nevents = forgetting_meter.nevents # Set the number of forgetting events for examples that the # model never got correct to infinity as in the original # paper.
executable, current_path, train_dataset, test_dataset, round_dir, tag, num_classes, lr=learning_rate, dim=dim, min_count=min_count, bucket=bucket, epoch=epochs, threads=threads, ngrams=ngrams, verbose=True) utils.save_result(train_stats, os.path.join(run_dir, "proxy.csv")) print('Selecting examples for size {}'.format(next_size)) # Rank examples based on the probabilities from fastText. ranking_start = datetime.now() ranking = calculate_rank(probs, selection_method) # Select top examples. if next_size > len(labeled): # Performing active learning. # Add top ranking examples to the existing labeled set. labeled_set = set(labeled) ranking = [i for i in ranking if i not in labeled_set] new_indices = ranking[:(next_size - len(labeled))] selection_stats['current_nexamples'] = len(labeled) selection_stats['new_nexamples'] = len(new_indices)
def run_epoch(
        epoch: int,
        global_step: int,
        model: nn.Module,
        loader: DataLoader,
        device: torch.device,
        criterion: Optional[nn.Module] = None,
        optimizer: Optional[Optimizer] = None,
        top: Tuple[int, ...] = (1, ),
        output_file: Optional[str] = None,
        train: bool = True,
        label: Optional[str] = None,
        batch_callback: Optional[Callable] = None,
        fp16: bool = False
) -> Tuple[int, List[utils.AverageMeter], timedelta]:
    """
    Run a single epoch of train or validation

    Parameters
    ----------
    epoch : int
        Current epoch of training
    global_step : int
        Current step in training (i.e., `epoch * (len(train_loader))`)
    model : nn.Module
        Pytorch model to run; it is called as `model(inputs)`
    loader : DataLoader
        Training or validation data. Each batch must unpack into
        `(indices, inputs, targets)`.
    device : torch.device
        Device that `targets` are moved to. `inputs` are passed to the
        model unchanged.
    criterion : nn.Module or None, default None
        Loss function to optimize for (training only)
    optimizer : Optimizer or None, default None
        Optimizer to optimize model (training only)
    top : Tuple[int]
        Specify points to calculate accuracies
        (e.g., top-1 & top-5 -> (1, 5))
    output_file : str or None, default None
        File path to log per-batch results. Nothing is logged when None.
    train : bool
        Indicate whether to train the model and update weights
    label : str or None, default None
        Label for tqdm output. Defaults to '(Train):' or '(Dev):'
        depending on `train`.
    batch_callback : Callable or None, default None
        Optional function to calculate stats on each batch. Called as
        `batch_callback(indices, inputs, targets, outputs)`.
    fp16 : bool, default False.
        Use mixed precision training.

    Returns
    -------
    global_step : int
        The incoming `global_step` plus one for every training batch
        (unchanged when `train` is False)
    accuracies : List[utils.AverageMeter]
        One meter per entry of `top`, updated with each batch's
        accuracy as a percentage
    total_time : timedelta
        Sum of the per-batch durations; the logging and progress bar
        updates at the end of each batch are not counted
    """
    if label is None:
        label = '(Train):' if train else '(Dev):'

    if train:
        assert criterion is not None, 'Need criterion to train model'
        _criterion = criterion
        assert optimizer is not None, 'Need optimizer to train model'
        _optimizer = optimizer
        if fp16:
            # amp is only used to scale the loss in the training branch.
            from apex import amp  # avoid dependency unless necessary.

    losses = utils.AverageMeter()
    accuracies = [utils.AverageMeter() for _ in top]
    wrapped_loader = tqdm(loader)

    model.train(train)

    if output_file is not None:
        # Only write the header row if the log does not exist yet.
        write_heading = not os.path.exists(output_file)

    with maybe_open(output_file) as out_file:
        with torch.set_grad_enabled(train):
            start = datetime.now()
            total_time = timedelta(0)
            for batch_index, (indices, inputs, targets) in enumerate(
                    wrapped_loader):
                batch_size = targets.size(0)
                assert batch_size < 2**32, (
                    'Size is too large! correct will overflow')

                targets = targets.to(device)
                outputs = model(inputs)

                if batch_callback is not None:
                    # Allow selection metrics like forgetting events
                    batch_callback(indices, inputs, targets, outputs)

                if train:
                    global_step += 1
                    loss = _criterion(outputs, targets)
                    _optimizer.zero_grad()
                    if fp16:
                        with amp.scale_loss(loss, _optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    _optimizer.step()
                    # Read the scalar once; it is reused for logging
                    # below instead of calling `.item()` a second time.
                    loss_value = loss.item()
                    losses.update(loss_value, batch_size)

                top_correct = utils.correct(outputs, targets, top=top)
                # Same idea as the loss: convert each count to a Python
                # number once and share it with the logging code.
                correct_counts = [count.item() for count in top_correct]
                for i, count in enumerate(correct_counts):
                    accuracies[i].update(
                        count * (100. / batch_size), batch_size)

                end = datetime.now()  # Don't count logging overhead
                duration = end - start
                total_time += duration

                if output_file is not None:
                    result: Dict[str, Any] = OrderedDict()
                    result['timestamp'] = datetime.now()
                    result['batch_duration'] = duration
                    result['global_step'] = global_step
                    result['epoch'] = epoch
                    result['batch'] = batch_index
                    result['batch_size'] = batch_size
                    for i, k in enumerate(top):
                        result[f'top{k}_correct'] = correct_counts[i]
                        result[f'top{k}_accuracy'] = accuracies[i].val
                    if train:
                        result['loss'] = loss_value
                    utils.save_result(result, out_file,
                                      write_heading=write_heading)
                    write_heading = False

                desc = 'Epoch {} {}'.format(epoch, label)
                if train:
                    desc += ' Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        loss=losses)
                for k, acc in zip(top, accuracies):
                    desc += ' Top-{} {acc.val:.3f} ({acc.avg:.3f})'.format(
                        k, acc=acc)
                wrapped_loader.set_description(desc, refresh=False)
                # Restart the clock here so the next batch's duration
                # includes fetching it from the loader.
                start = datetime.now()

    return global_step, accuracies, total_time