def validate(val_loader, model, criterion, epoch, start_time):
    """Run one full evaluation pass over *val_loader*.

    Returns (top1.avg, top5.avg) — sample-weighted average accuracies.
    Logs per-batch stats on local rank 0 and final eval metrics to tensorboard.
    """
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    eval_start_time = time.time()

    for step, (input, target) in enumerate(val_loader):
        # Optionally cut the epoch short for smoke tests.
        if args.short_epoch and step > 10:
            break
        batch_num = step + 1
        timer.batch_start()

        if args.distributed:
            # Helper aggregates predictions/metrics across workers.
            top1acc, top5acc, loss, batch_total = distributed_predict(
                input, target, model, criterion)
        else:
            with torch.no_grad():
                preds = model(input)
                loss = criterion(preds, target).data
            batch_total = input.size(0)
            top1acc, top5acc = accuracy(preds.data, target, topk=(1, 5))

        # Eval batch done. Logging results
        timer.batch_end()
        weight = to_python_float(batch_total)
        losses.update(to_python_float(loss), weight)
        top1.update(to_python_float(top1acc), weight)
        top5.update(to_python_float(top5acc), weight)

        is_last = batch_num == len(val_loader)
        if args.local_rank == 0 and (batch_num % args.print_freq == 0 or is_last):
            msg = (
                f'Test: [{epoch}][{batch_num}/{len(val_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            log.verbose(msg)

    tb.log_eval(top1.avg, top5.avg, time.time() - eval_start_time)
    tb.log('epoch', epoch)
    return top1.avg, top5.avg
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    """Train *model* for one epoch over *trn_loader*.

    Supports manual fp16 loss scaling (args.fp16) with separate fp16 model
    params and fp32 master params, and distributed metric aggregation via
    dist_utils.sum_tensor. Logs timing/loss/accuracy on local rank 0.

    NOTE(review): a second `train` definition appears later in this file and
    will shadow this one at import time — confirm which variant is intended.
    """
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        # Optionally cut the epoch short for smoke tests.
        if args.short_epoch and (i > 10):
            break
        batch_num = i + 1
        timer.batch_start()
        # Scheduler is stepped per-batch, not per-epoch.
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        if args.fp16:
            # Manual loss scaling: scale up before backward so small fp16
            # gradients don't underflow, then unscale the fp32 master grads.
            # NOTE: the order here is load-bearing — scale, backward, copy
            # grads to master, unscale, step, copy params back, unscale loss.
            loss = loss * args.loss_scale
            model.zero_grad()
            loss.backward()
            # model_params/master_params are presumably module-level globals
            # set up by the fp16 initialization — TODO confirm.
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data / args.loss_scale
            optimizer.step()
            # Copy updated fp32 master weights back into the fp16 model.
            master_params_to_model_params(model_params, master_params)
            # Undo the scaling so the logged loss is the true loss.
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss, batch_total = to_python_float(
            loss.data), to_python_float(input.size(0))
        if args.distributed:
            # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
            metrics = torch.tensor([batch_total, reduced_loss, corr1, corr5]).float().cuda()
            # All-reduce (sum) the packed metrics across workers.
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(
                metrics).cpu().numpy()
            # Loss was summed across workers; average it back.
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        # Convert correct-counts to percentages over the (global) batch size.
        top1acc = to_python_float(corr1) * (100.0 / batch_total)
        top5acc = to_python_float(corr5) * (100.0 / batch_total)
        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        should_print = (batch_num % args.print_freq == 0) or (batch_num == len(trn_loader))
        if args.local_rank == 0 and should_print:
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val, input.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)
            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)
            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)

        tb.update_step_count(batch_total)
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    """Train *model* for one epoch over *trn_loader* (BytePS variant).

    Differs from the earlier `train` definition: fp16 grad handling is
    delegated to the optimizer, and distributed metric aggregation uses
    bps.push_pull instead of dist_utils.sum_tensor.

    NOTE(review): this redefines the `train` defined earlier in the file and
    shadows it at import time — confirm which variant is intended.
    """
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    for i, (input, target) in enumerate(trn_loader):
        # Optionally cut the epoch short for smoke tests.
        if args.short_epoch and (i > 10):
            break
        batch_num = i + 1
        timer.batch_start()
        # Scheduler is stepped per-batch, not per-epoch.
        scheduler.update_lr(epoch, i + 1, len(trn_loader))

        # compute output
        output = model(input)
        loss = criterion(output, target)
        # Computed up front because it is also passed to optimizer.step below.
        should_print = (batch_num % args.print_freq == 0) or (batch_num == len(trn_loader))

        # compute gradient and do SGD step
        if args.fp16:
            # Scale up so small fp16 gradients don't underflow in backward.
            loss = loss * args.loss_scale
            # zero_grad() and converting fp16/fp32 is handled in optimizer
            loss.backward()
            # wait_for_finish presumably blocks on async push-pull only when
            # we are about to read metrics for logging — TODO confirm.
            optimizer.step(wait_for_finish=should_print)
            # Undo the scaling so the logged loss is the true loss.
            loss = loss / args.loss_scale
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        # Metrics are only computed/aggregated on print iterations in this
        # variant; `batch_total` is therefore only bound inside this guard.
        # NOTE(review): bps.push_pull is a collective — guarding it with
        # `args.local_rank == 0` assumes exactly one participating process
        # per machine; verify against the launcher configuration.
        if args.local_rank == 0 and should_print:
            corr1, corr5 = correct(output.data, target, topk=(1, 5))
            reduced_loss, batch_total = to_python_float(
                loss.data), to_python_float(input.size(0))
            if args.distributed:
                # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
                # `validate_tensor` is presumably a preallocated module-level
                # 4-element CUDA tensor reused each iteration — TODO confirm.
                validate_tensor[0] = batch_total
                validate_tensor[1] = reduced_loss
                validate_tensor[2] = corr1
                validate_tensor[3] = corr5
                # Sum (not average) the packed metrics across workers.
                batch_total, reduced_loss, corr1, corr5 = bps.push_pull(
                    validate_tensor, average=False, name="validation_tensor")
                batch_total = batch_total.cpu().numpy()
                reduced_loss = reduced_loss.cpu().numpy()
                corr1 = corr1.cpu().numpy()
                corr5 = corr5.cpu().numpy()
                # Loss was summed across workers; average it back.
                reduced_loss = reduced_loss / bps.size()
            # Convert correct-counts to percentages over the (global) batch size.
            top1acc = to_python_float(corr1) * (100.0 / batch_total)
            top5acc = to_python_float(corr5) * (100.0 / batch_total)
            losses.update(reduced_loss, batch_total)
            top1.update(top1acc, batch_total)
            top5.update(top5acc, batch_total)
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val, input.size(0))
            tb.log_trn_loss(losses.val, top1.val,
                            top5.val)
            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)
            output = (
                f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(output)
            tb.update_step_count(batch_total)