def distributed_predict(input, target, model, criterion, cnt):
    """Run one distributed validation step, tolerating ranks with no data.

    The test set isn't always large enough for every GPU to receive a
    batch, so each rank records whether it actually evaluated one and the
    statistics are summed across all workers via push_pull before the
    global accuracy/loss are derived.
    """
    local_batch = input.size(0)
    model_out = loss_val = top1_correct = top5_correct = has_batch = 0
    if local_batch:
        with torch.no_grad():
            model_out = model(input)
            loss_val = criterion(model_out, target).data
        # This rank contributed a batch; record its accuracy counts.
        has_batch = 1
        top1_correct, top5_correct = correct(model_out.data, target, topk=(1, 5))

    # Pack the local statistics into the shared tensor and sum across workers.
    dist_validate_tensor[0] = local_batch
    dist_validate_tensor[1] = has_batch
    dist_validate_tensor[2] = loss_val
    dist_validate_tensor[3] = top1_correct
    dist_validate_tensor[4] = top5_correct
    batch_total, valid_batches, reduced_loss, corr1, corr5 = bps.push_pull(
        dist_validate_tensor, average=False,
        name="distributed_validation_tensor")

    # Loss is averaged over ranks that had data; accuracy over all examples.
    reduced_loss = reduced_loss / valid_batches
    top1 = corr1 * (100.0 / batch_total)
    top5 = corr5 * (100.0 / batch_total)
    return top1, top5, reduced_loss, batch_total
def metric_average(val, name):
    """Average scalar *val* across all workers via push_pull; return a float.

    *name* identifies the reduction so repeated calls for different
    metrics don't collide.
    """
    metric = torch.tensor(val)
    if args.cuda:
        metric = metric.cuda()
    averaged = bps.push_pull(metric, name=name)
    return averaged.item()
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    """Train the model for one epoch over *trn_loader*.

    Performs the forward/backward/step loop (with optional fp16 loss
    scaling), and on rank 0 periodically aggregates loss/accuracy across
    workers via push_pull and logs them to the tensorboard wrapper `tb`.
    """
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        # --short-epoch truncates the epoch to ~11 batches for quick runs.
        if args.short_epoch and (i > 10):
            break
        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # Log every print_freq batches and always on the last batch.
        should_print = (batch_num % args.print_freq == 0) or (batch_num == len(trn_loader))

        # compute gradient and do SGD step
        if args.fp16:
            # Scale the loss up before backward to keep fp16 gradients in
            # range, then undo the scaling so the logged loss is unscaled.
            loss = loss * args.loss_scale
            # zero_grad() and converting fp16/fp32 is handled in optimizer
            loss.backward()
            optimizer.step(wait_for_finish=should_print)
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        if args.local_rank == 0 and should_print:
            corr1, corr5 = correct(output.data, target, topk=(1, 5))
            reduced_loss, batch_total = to_python_float(
                loss.data), to_python_float(input.size(0))
            if args.distributed:
                # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
                validate_tensor[0] = batch_total
                validate_tensor[1] = reduced_loss
                validate_tensor[2] = corr1
                validate_tensor[3] = corr5
                # Sum the per-rank stats across all workers.
                batch_total, reduced_loss, corr1, corr5 = bps.push_pull(
                    validate_tensor, average=False, name="validation_tensor")
                batch_total = batch_total.cpu().numpy()
                reduced_loss = reduced_loss.cpu().numpy()
                corr1 = corr1.cpu().numpy()
                corr5 = corr5.cpu().numpy()
                # push_pull summed the losses; divide by worker count to average.
                reduced_loss = reduced_loss / bps.size()
            top1acc = to_python_float(corr1) * (100.0 / batch_total)
            top5acc = to_python_float(corr5) * (100.0 / batch_total)
            losses.update(reduced_loss, batch_total)
            top1.update(top1acc, batch_total)
            top5.update(top5acc, batch_total)
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val,
                             timer.data_time.val, input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)
            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)
            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)
            # NOTE(review): batch_total is only bound inside this branch, so
            # the step counter can only be updated on logged batches.
            tb.update_step_count(batch_total)