def test(args, model, device, test_loader_creator, logger):
    """Evaluate `model` over every loader supplied by `test_loader_creator`.

    Runs a cross-entropy evaluation pass (no gradients) and logs the
    averaged loss and accuracy through `logger`.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss().to(device)
    loss_meter, acc_meter = AverageMeter(), AverageMeter()
    with torch.no_grad():
        for loader in test_loader_creator.data_loaders:
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                # model returns (features, logits); only logits are scored
                _, logits = model(inputs)
                batch_loss = criterion(logits, labels)
                logits = logits.float()
                batch_loss = batch_loss.float()
                batch_acc = accuracy(logits.data, labels)[0]
                n = inputs.size(0)
                loss_meter.update(batch_loss.item(), n)
                acc_meter.update(batch_acc.item(), n)
    logger.info('Test set: Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Acc {acc.avg:.3f}'.format(loss=loss_meter, acc=acc_meter))
def valid_func(xloader, network, criterion):
    """Run one validation pass over `xloader`.

    Returns (avg_loss, avg_top1, avg_top5) computed without gradients.
    """
    timer_data, timer_batch = AverageMeter(), AverageMeter()
    meter_loss, meter_top1, meter_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.eval()
    tick = time.time()
    with torch.no_grad():
        for step, (arch_inputs, arch_targets) in enumerate(xloader):
            arch_targets = arch_targets.cuda(non_blocking=True)
            # time spent fetching this batch
            timer_data.update(time.time() - tick)
            # network returns (features, logits); score the logits
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            prec1, prec5 = obtain_accuracy(logits.data,
                                           arch_targets.data,
                                           topk=(1, 5))
            n = arch_inputs.size(0)
            meter_loss.update(arch_loss.item(), n)
            meter_top1.update(prec1.item(), n)
            meter_top5.update(prec5.item(), n)
            # total wall time for the batch, then restart the clock
            timer_batch.update(time.time() - tick)
            tick = time.time()
    return meter_loss.avg, meter_top1.avg, meter_top5.avg
def train_shared_cnn(xloader, shared_cnn, criterion, scheduler, optimizer,
                     print_freq, logger, config, start_epoch):
    """Train `shared_cnn` for `config.epochs + config.warmup` epochs.

    Fix: corrected the 'Traing' typo in the per-epoch log line.

    Args:
        xloader: training data loader yielding (inputs, targets).
        shared_cnn: the shared network being trained.
        criterion: loss function applied to the logits.
        scheduler: LR scheduler with an `update(epoch, fraction)` API.
        optimizer: weight optimizer.
        print_freq: log every `print_freq` steps.
        logger: object with a `.log(str)` method.
        config: provides `epochs` and `warmup` epoch counts.
        start_epoch: epoch index to resume from.
    """
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(
    ), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        scheduler.update(epoch, 0.0)
        # estimate remaining time from the running per-epoch average
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Training the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(scheduler.get_lr())))
        data_time, batch_time = AverageMeter(), AverageMeter()
        losses, top1s, top5s, xend = AverageMeter(), AverageMeter(
        ), AverageMeter(), time.time()
        shared_cnn.train()
        for step, (inputs, targets) in enumerate(xloader):
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)
            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            # clip gradients for stability before stepping
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data, targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()
            if step % print_freq == 0 or step + 1 == len(xloader):
                Sstr = '*Train-Shared-CNN* ' + time_string(
                ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                    loss=losses, top1=top1s, top5=top5s)
                logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
        cnn_loss, cnn_top1, cnn_top5 = losses.avg, top1s.avg, top5s.avg
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    return
def train_or_test_epoch(self, xloader, model, loss_fn, metric_fn, is_train,
                        optimizer=None):
    """Run one epoch of training (or evaluation) over `xloader`.

    Returns (avg_loss, avg_score) across all batches.

    NOTE(review): the `metric_fn` parameter is never used — the body calls
    `self.metric_fn` instead. Confirm whether callers rely on the attribute
    before removing either one.
    """
    if is_train:
        model.train()
    else:
        model.eval()
    score_meter, loss_meter = AverageMeter(), AverageMeter()
    for ibatch, (feats, labels) in enumerate(xloader):
        feats = feats.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        # forward the network
        preds = model(feats)
        loss = loss_fn(preds, labels)
        # metric is for reporting only; keep it out of the autograd graph
        with torch.no_grad():
            score = self.metric_fn(preds, labels)
        loss_meter.update(loss.item(), feats.size(0))
        score_meter.update(score.item(), feats.size(0))
        # optimize the network
        if is_train and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            # clip by value (not norm) at 3.0 before stepping
            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
            optimizer.step()
    return loss_meter.avg, score_meter.avg
def pure_evaluate(xloader, network, criterion=None):
    """Evaluate `network` over `xloader` and measure per-batch latency.

    Fix: the original default argument `criterion=torch.nn.CrossEntropyLoss()`
    constructed the loss module once at function-definition time; a `None`
    sentinel now creates it per call (identical behavior for all callers).

    Args:
        xloader: data loader yielding (inputs, targets).
        network: model returning (features, logits).
        criterion: loss function; defaults to a fresh CrossEntropyLoss.

    Returns:
        (avg_loss, avg_top1, avg_top5, latencies) where `latencies` are
        per-batch forward times (seconds) with the warmup batch dropped
        when enough samples exist.
    """
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            # only record latency for full-size batches (the last batch
            # may be smaller and would skew the measurement)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    # drop the first (warmup) latency when we have enough samples
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
def test_contrastive(args, model, nearest_proto_model, device,
                     test_loader_creator_l, logger):
    """Evaluate a contrastive model with nearest-prototype classification.

    Fix: corrected the 'Tess Acc per task' typo in the logged string.

    For each task loader, embeds the inputs with `model`, classifies via
    `nearest_proto_model.predict`, and accumulates global and per-task
    accuracy. Logs per-task accuracy only when `args.acc_per_task` is set.
    """
    model.eval()
    acc = AverageMeter()
    tasks_acc = [
        AverageMeter() for i in range(len(test_loader_creator_l.data_loaders))
    ]
    test_loaders_l = test_loader_creator_l.data_loaders
    with torch.no_grad():
        for task_idx, test_loader_l in enumerate(test_loaders_l):
            # loader yields (data, augmented_view, target); the second
            # element is unused at test time
            for batch_idx, (data, _, target) in enumerate(test_loader_l):
                data, target = data.to(device), target.to(device)
                cur_feats, _ = model(data)
                output = nearest_proto_model.predict(cur_feats)
                # fraction of correct predictions in this batch
                it_acc = (output == target).sum().item() / data.shape[0]
                acc.update(it_acc, data.size(0))
                tasks_acc[task_idx].update(it_acc, data.size(0))
    if args.acc_per_task:
        tasks_acc_str = 'Test Acc per task: '
        for i, task_acc in enumerate(tasks_acc):
            tasks_acc_str += 'Task{:2d} Acc: {acc.avg:.3f}'.format(
                (i + 1), acc=task_acc) + '\t'
        logger.info(tasks_acc_str)
    logger.info('Test Acc: {acc.avg:.3f}'.format(acc=acc))
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    """Run one epoch in 'train' or 'valid' mode over `xloader`.

    Returns (avg_loss, avg_top1, avg_top5, total_batch_time).
    Raises ValueError for any other mode string.
    """
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    is_train = mode == 'train'
    if is_train:
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for idx, (batch_in, batch_tg) in enumerate(xloader):
        if is_train:
            # advance the LR schedule fractionally within the epoch
            scheduler.update(None, 1.0 * idx / len(xloader))
        batch_tg = batch_tg.cuda(device=device, non_blocking=True)
        if is_train:
            optimizer.zero_grad()
        # forward
        features, logits = network(batch_in)
        loss = criterion(logits, batch_tg)
        # backward
        if is_train:
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, batch_tg.data,
                                       topk=(1, 5))
        n = batch_in.size(0)
        losses.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
def train_bptt(num_epochs: int, model, dset_train, batch_size: int, T: int,
               w_checkpoint_freq: int, grad_clip: float, w_lr: float,
               logging_freq: int, sotl_order: int, hvp: str):
    """Train weights with BPTT-style unrolling and update architecture
    parameters from the SoTL (sum-of-training-losses) gradient.

    NOTE(review): `criterion`, `w_optimizer`, `a_optimizer`, `wandb`,
    `WeightBuffer`, and `sotl_gradient` are resolved from the enclosing
    module scope, not passed in — confirm they are defined at call time.
    NOTE(review): the `grad_clip` parameter is unused; the clip norm is
    hard-coded to 1 below.
    """
    model.train()
    # each outer batch holds T inner steps worth of data
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size * T,
                                               shuffle=True)
    for epoch in range(num_epochs):
        epoch_loss = AverageMeter()
        true_batch_index = 0
        for batch_idx, batch in enumerate(train_loader):
            # split the big batch into T mini-batches for the unrolled loop
            xs, ys = torch.split(batch[0], batch_size), torch.split(
                batch[1], batch_size)
            weight_buffer = WeightBuffer(T=T,
                                         checkpoint_freq=w_checkpoint_freq)
            for intra_batch_idx, (x, y) in enumerate(zip(xs, ys)):
                # snapshot weights before this inner step for later replay
                weight_buffer.add(model, intra_batch_idx)
                y_pred = model(x)
                loss = criterion(y_pred, y)
                epoch_loss.update(loss.item())
                # create_graph=True keeps the graph alive so the arch
                # gradient can differentiate through the weight update
                grads = torch.autograd.grad(loss,
                                            model.weight_params(),
                                            retain_graph=True,
                                            allow_unused=True,
                                            create_graph=True)
                w_optimizer.zero_grad()
                with torch.no_grad():
                    for g, w in zip(grads, model.weight_params()):
                        w.grad = g
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                w_optimizer.step()
                true_batch_index += 1
                if true_batch_index % logging_freq == 0:
                    print("Epoch: {}, Batch: {}, Loss: {}".format(
                        epoch, true_batch_index, epoch_loss.avg))
                    wandb.log({"Train loss": epoch_loss.avg})
            # architecture update from the replayed weight trajectory
            total_arch_gradient = sotl_gradient(model,
                                                criterion,
                                                xs,
                                                ys,
                                                weight_buffer,
                                                w_lr=w_lr,
                                                hvp=hvp,
                                                order=sotl_order)
            a_optimizer.zero_grad()
            for g, w in zip(total_arch_gradient, model.arch_params()):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(model.arch_params(), 1)
            a_optimizer.step()
def train_shared_cnn(
        xloader,
        shared_cnn,
        controller,
        criterion,
        scheduler,
        optimizer,
        epoch_str,
        print_freq,
        logger,
):
    """Train the shared CNN for one epoch, sampling a fresh architecture
    from `controller` for every batch.

    Returns (avg_loss, avg_top1, avg_top5).

    NOTE(review): `shared_cnn` is assumed to be wrapped (e.g. in
    DataParallel) since the arch is set via `shared_cnn.module` — confirm
    against the caller.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        time.time(),
    )
    shared_cnn.train()
    controller.eval()
    for step, (inputs, targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)
        # sample an architecture without tracking gradients through the
        # controller
        with torch.no_grad():
            _, _, sampled_arch = controller()
        optimizer.zero_grad()
        shared_cnn.module.update_arch(sampled_arch)
        _, logits = shared_cnn(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
        optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data,
                                       topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1s.update(prec1.item(), inputs.size(0))
        top5s.update(prec5.item(), inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*Train-Shared-CNN* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=losses, top1=top1s, top5=top5s)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return losses.avg, top1s.avg, top5s.avg
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler,
                     optimizer, epoch_str, print_freq, logger):
    """Train the shared CNN over `ne` sampled architectures, one full pass
    of `xloader` per architecture.

    Fixes:
    - The per-architecture log used '[{:03d}/10]'.format(ni, ne): one
      placeholder for two arguments with the total hard-coded; now
      '[{:03d}/{:03d}]'.format(ni, ne).
    - The meters were reset *before* the final return, so the function
      always returned zeroed averages; the averages of the last sampled
      architecture are now captured before the reset and returned.

    Returns (loss_avg, top1_avg, top5_avg) for the last architecture.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(
    ), time.time()
    shared_cnn.train()
    controller.eval()
    # last-architecture stats, captured before each reset
    last_loss, last_top1, last_top5 = 0.0, 0.0, 0.0
    ne = 10  # number of architectures to sample and train on
    for ni in range(ne):
        # sample an architecture without backprop through the controller
        with torch.no_grad():
            _, _, sampled_arch = controller()
        shared_cnn.module.update_arch(sampled_arch)
        print(sampled_arch)
        # arch_str = op_list2str(sampled_arch)
        for step, (inputs, targets) in enumerate(xloader):
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)
            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record (note: top-k here is (1, 2), matching the original)
            prec1, prec5 = obtain_accuracy(logits.data, targets.data,
                                           topk=(1, 2))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()
        Sstr = '*Train-Shared-CNN* ' + time_string(
        ) + ' [{:03d}/{:03d}]'.format(ni, ne)
        Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
            batch_time=batch_time, data_time=data_time)
        Wstr = '[Loss {loss.avg:.3f} Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
            loss=losses, top1=top1s, top5=top5s)
        # keep the averages before resetting for the next architecture
        last_loss, last_top1, last_top5 = losses.avg, top1s.avg, top5s.avg
        losses.reset()
        top1s.reset()
        top5s.reset()
        logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
    return last_loss, last_top1, last_top5
def valid_func(model, val_loader, criterion):
    """Compute the average validation loss of `model` over `val_loader`.

    Prints the average loss and returns the populated meter.
    """
    model.eval()
    val_meter = AverageMeter()
    with torch.no_grad():
        for x, y in val_loader:
            predictions = model(x)
            val_meter.update(criterion(predictions, y).item())
    print("Val loss: {}".format(val_meter.avg))
    return val_meter
def eval_robust_heatmap(detector, xloader, print_freq, logger):
    """Evaluate a heatmap-based landmark detector under multiple augmented
    views per image and score robustness of the averaged predictions.

    Returns (errors, valids, eval_meta) from `calculate_robust`.

    NOTE(review): each loader item appears to hold `iters` augmented views
    per image with per-view affine `thetas` — confirm against the dataset.
    """
    batch_time, NUM_PTS = AverageMeter(), xloader.dataset.NUM_PTS
    Preds, GT_locs, Distances = [], [], []
    eval_meta, end = Eval_Meta(), time.time()
    with torch.no_grad():
        detector.eval()
        for i, (inputs, heatmaps, masks, norm_points, thetas, data_index,
                nopoints, xshapes) in enumerate(xloader):
            data_index = data_index.squeeze(1).tolist()
            batch_size, iters, C, H, W = inputs.size()
            for ibatch in range(batch_size):
                xinputs, xpoints, xthetas = inputs[ibatch], norm_points[
                    ibatch].permute(0, 2, 1).contiguous(), thetas[ibatch]
                # run the detector on all `iters` views of this image at once
                batch_features, batch_heatmaps, batch_locs, batch_scos = detector(
                    xinputs.cuda(non_blocking=True))
                # drop the last (background) channel of the locations
                batch_locs = batch_locs.cpu()[:, :-1]
                all_locs = []
                for _iter in range(iters):
                    # map heatmap coordinates back to the original image via
                    # the per-view affine transform
                    _locs = normalize_points((H, W),
                                             batch_locs[_iter].permute(1, 0))
                    xlocs = torch.cat((_locs, torch.ones(1, NUM_PTS)), dim=0)
                    nlocs = torch.mm(xthetas[_iter, :2], xlocs)
                    rlocs = denormalize_points(xshapes[ibatch].tolist(), nlocs)
                    rlocs = torch.cat(
                        (rlocs.permute(1, 0), xpoints[_iter, :, 2:]), dim=1)
                    all_locs.append(rlocs.clone())
                GT_loc = xloader.dataset.labels[
                    data_index[ibatch]].get_points()
                norm_distance = xloader.dataset.get_normalization_distance(
                    data_index[ibatch])
                # save the results (average prediction across views)
                eval_meta.append((sum(all_locs) / len(all_locs)).numpy().T,
                                 GT_loc.numpy(),
                                 xloader.dataset.datas[data_index[ibatch]],
                                 norm_distance)
                Distances.append(norm_distance)
                Preds.append(all_locs)
                GT_locs.append(GT_loc.permute(1, 0))
            # compute time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0 or i + 1 == len(xloader):
                last_time = convert_secs2time(
                    batch_time.avg * (len(xloader) - i - 1), True)
                logger.log(
                    ' -->>[Robust HEATMAP-based Evaluation] [{:03d}/{:03d}] Time : {:}'
                    .format(i, len(xloader), last_time))
    # evaluate the results
    errors, valids = calculate_robust(Preds, GT_locs, Distances, NUM_PTS)
    return errors, valids, eval_meta
def valid_func(model, dset_val, criterion, print_results=True):
    """Evaluate `model` on `dset_val` (batched at 32) and return the loss meter.

    Optionally prints the average loss when `print_results` is True.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dset_val, batch_size=32)
    val_meter = AverageMeter()
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            batch_loss = criterion(outputs, labels)
            val_meter.update(batch_loss.item())
    if print_results:
        print("Val loss: {}".format(val_meter.avg))
    return val_meter
def search(self):
    """Run the evolutionary architecture search for `self.max_epochs`
    generations.

    Each generation: rank current candidates, update the global best,
    select `self.parent_num` parents, then refill the population via
    mutation and crossover.

    Returns:
        (bests_per_epoch, perform_dict, perform_trace): best candidate per
        generation, the evaluation cache, and the per-generation
        performance tensors.
    """
    self.eva_time = AverageMeter()
    init_start = time.time()
    self.init_random()
    self.logger.log('Initial_takes: %.2f' % (time.time() - init_start))
    epoch_start_time = time.time()
    epoch_time_meter = AverageMeter()
    bests_per_epoch = list()
    perform_trace = list()
    for i in range(self.max_epochs):
        self.performances = torch.Tensor(self.performances)
        # indices of the best candidates, highest performance first
        top_k = torch.argsort(self.performances,
                              descending=True)[:self.parent_num]
        if self.best_perf is None or self.performances[
                top_k[0]] > self.best_perf:
            self.best_cand = self.candidates[top_k[0]]
            self.best_perf = self.performances[top_k[0]]
        bests_per_epoch.append(self.best_cand)
        perform_trace.append(self.performances)
        self.parents = []
        for idx in top_k:
            self.parents.append(self.candidates[idx])
        # clear the population; mutation/crossover refill it
        self.candidates, self.performances = list(), list()
        # reset so this generation's evaluation time is reported alone
        self.eva_time = AverageMeter()
        self.get_mutation(self.population_num // 2)
        self.get_crossover()
        self.logger.log(
            '*SEARCH* ' + time_string() +
            '||| Epoch: %2d finished, %3d models have been tested, best performance is %.2f'
            % (i, len(self.perform_dict.keys()), self.best_perf))
        self.logger.log(' - Best Cand: ' + str(self.best_cand))
        this_epoch_time = time.time() - epoch_start_time
        epoch_time_meter.update(this_epoch_time)
        epoch_start_time = time.time()
        self.logger.log('Time for Epoch %d : %.2fs' % (i, this_epoch_time))
        self.logger.log(' -- Evaluated %d models, with %.2f s in average' %
                        (self.eva_time.count, self.eva_time.avg))
    self.logger.log(
        '--------\nSearching Finished. Best Arch Found with Acc %.2f' %
        (self.best_perf))
    self.logger.log(str(self.best_cand))
    #torch.save(self.best_cand, self.save_dir+'/best_arch.pth')
    #torch.save(self.perform_dict, self.save_dir+'/perform_dict.pth')
    return bests_per_epoch, self.perform_dict, perform_trace
def train_normal(num_epochs, model, dset_train, batch_size, grad_clip,
                 logging_freq, optim="sgd", **kwargs):
    """Train `model` on `dset_train` with either SGD or a Newton step.

    NOTE(review): `w_optimizer`, `criterion`, `hessian`, and `wandb` come
    from the enclosing module scope — confirm they are defined at call time.
    NOTE(review): the `grad_clip` parameter is unused; the clip norm is
    hard-coded to 1 in the sgd branch.
    """
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size,
                                               shuffle=True)
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = AverageMeter()
        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            w_optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward(retain_graph=True)
            epoch_loss.update(loss.item())
            if optim == "newton":
                # second-order step: grad premultiplied by the inverse
                # Hessian of the (single) linear weight
                linear_weight = list(model.weight_params())[0]
                hessian_newton = torch.inverse(
                    hessian(loss * 1, linear_weight,
                            linear_weight).reshape(linear_weight.size()[1],
                                                   linear_weight.size()[1]))
                with torch.no_grad():
                    for w in model.weight_params():
                        # subtract_ updates in place; the rebinding of `w`
                        # is redundant but harmless
                        w = w.subtract_(torch.matmul(w.grad, hessian_newton))
            elif optim == "sgd":
                torch.nn.utils.clip_grad_norm_(model.weight_params(), 1)
                w_optimizer.step()
            else:
                raise NotImplementedError
            wandb.log({
                "Train loss": epoch_loss.avg,
                "Epoch": epoch,
                "Batch": batch_idx
            })
            if batch_idx % logging_freq == 0:
                print("Epoch: {}, Batch: {}, Loss: {}, Alphas: {}".format(
                    epoch, batch_idx, epoch_loss.avg, model.fc1.alphas.data))
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    """Validate a search-mode network over `xloader`.

    Puts the network in 'search' mode via `change_key`, evaluates without
    gradients, and logs running statistics every `print_freq` batches.

    Returns (avg_loss, avg_top1, avg_top5).
    """
    data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter()
    network.eval()
    # switch all submodules into search mode
    network.apply(change_key('search_mode', 'search'))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)
            # network returns (logits, expected_flop) in search mode
            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data, targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = '**VALID** ' + time_string(
                ) + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                    loss=losses, top1=top1, top5=top5)
                Istr = 'Size={:}'.format(list(inputs.size()))
                logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
    logger.log(
        ' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'
        .format(top1=top1,
                top5=top5,
                error1=100 - top1.avg,
                error5=100 - top5.avg,
                loss=losses.avg))
    return losses.avg, top1.avg, top5.avg
def test_archi_acc(self, arch):
    """Measure top-1/top-5 accuracy and per-image latency for `arch`.

    Fixes:
    - `one_batch == None` replaced with `one_batch is None` (identity test;
      `==` on a tensor is element-wise semantics, not an identity check).
    - `batchsize` is now refreshed for every batch so a smaller final batch
      is weighted correctly in the meters.

    Returns:
        (top1_avg, top5_avg, time_per_image) — latency is 0.0 unless
        `self.lambda_t > 0`.
    """
    # optional BN-recalibration pass over the training data
    if self.train_loader is not None:
        self.model.apply(ResetRunningStats)
        self.model.train()
        for step, (data, target) in enumerate(self.train_loader):
            output = self.model.forward(data, arch)  # _with_architect
            del data, target, output
    base_top1, base_top5 = AverageMeter(), AverageMeter()
    self.model.eval()
    one_batch = None
    for step, (data, target) in enumerate(self.val_loader):
        if one_batch is None:
            one_batch = data  # keep the first batch for latency timing
        batchsize = data.shape[0]
        target = target.cuda(non_blocking=True)
        _, logits = self.model.forward(data, arch)  # _with_architect
        prec1, prec5 = obtain_accuracy(logits.data, target.data, topk=(1, 5))
        base_top1.update(prec1.item(), batchsize)
        base_top5.update(prec5.item(), batchsize)
        del data, target, logits, prec1, prec5
    if self.lambda_t > 0.0:
        # time single-image forwards over at most 50 samples
        start_time = time.time()
        len_batch = min(len(one_batch), 50)
        for i in range(len_batch):
            _, _ = self.model.forward(one_batch[i:i + 1, :, :, :], arch)
        end_time = time.time()
        time_per = (end_time - start_time) / len_batch
    else:
        time_per = 0.0
    return base_top1.avg, base_top5.avg, time_per
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, logger):
    """One search epoch using uniform random sampling ('urs') of sub-paths:
    only the shared weights are trained.

    NOTE(review): the arch meters (`arch_losses`, `arch_top1`, `arch_top5`)
    are never updated in this variant — they log and return as zeros, and
    `a_optimizer` is unused. Confirm this is intentional for 'urs' mode.

    Returns (base_loss, base_top1, base_top5, arch_loss, arch_top1,
    arch_top5) averages.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # sample a sub-path uniformly at random and train only the weights
        network.module.set_cal_mode('urs')
        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                epoch_str, print_freq, logger):
    """One search epoch for random-genotype training: sample a random
    genotype per batch and update only the shared weights.

    Returns (avg_base_loss, avg_base_top1, avg_base_top5).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        # arch split is loaded but unused in this weight-only variant
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # update the weights
        network.module.random_genotype(True)
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return base_losses.avg, base_top1.avg, base_top5.avg
def procedure(
        xloader,
        network,
        criterion,
        optimizer,
        metric,
        mode: Text,
        logger_fn: Callable = None,
):
    """Run one epoch in 'train' or 'valid' mode and return `metric.get_info()`.

    Fix: the mode was validated case-insensitively (`mode.lower()`) but the
    in-loop checks compared the raw string (`mode == "train"`), so a caller
    passing e.g. "Train" would set the network to train mode yet never call
    the optimizer. The mode is now normalized once and used consistently.

    Args:
        xloader: data loader yielding (inputs, targets).
        network: model producing outputs scored by `metric`.
        criterion: loss function (used in train mode only).
        optimizer: optimizer (used in train mode only).
        metric: stateful metric object; `get_info()` is the return value.
        mode: 'train' or 'valid' (case-insensitive).
        logger_fn: accepted for interface compatibility; unused here.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    mode = mode.lower()  # normalize once; all later checks use this
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        if mode == "train":
            optimizer.zero_grad()
        outputs = network(inputs)
        targets = targets.to(get_device(outputs))
        if mode == "train":
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        # record metric state (no gradients needed for scoring)
        with torch.no_grad():
            metric(outputs, targets)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, algo, epoch_str, print_freq, logger):
    """One search epoch alternating weight and architecture updates.

    The architecture loss depends on `algo`: 'tunas' uses a REINFORCE-style
    advantage against the global `RL_BASELINE_EMA` baseline; 'tas'/'fbv2'
    use the plain cross-entropy on the arch split.

    Fix: corrected the 'algorightm' typo in the raised error message.

    Returns (base_loss, base_top1, base_top5, arch_loss, arch_top1,
    arch_top5) averages.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # Update the weights
        network.zero_grad()
        _, logits, _ = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # update the architecture-weight
        network.zero_grad()
        _, logits, log_probs = network(arch_inputs)
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        if algo == 'tunas':
            # REINFORCE: reward is top-1 advantage over the EMA baseline
            with torch.no_grad():
                RL_BASELINE_EMA.update(arch_prec1.item())
                rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
            rl_log_prob = sum(log_probs)
            arch_loss = -rl_advantage * rl_log_prob
        elif algo == 'tas' or algo == 'fbv2':
            arch_loss = criterion(logits, arch_targets)
        else:
            raise ValueError('invalid algorithm name: {:}'.format(algo))
        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def valid_func(model,
               dset_val,
               criterion,
               device='cuda' if torch.cuda.is_available() else 'cpu',
               print_results=True):
    """Evaluate `model` on `dset_val`, tracking loss and (for
    classification) accuracy.

    Fix: the original decided whether to print accuracy by checking
    `val_acc_meter.avg > 0`, which wrongly printed "Not applicable" when a
    classifier's accuracy was legitimately 0.0. An explicit flag based on
    the criterion type is used instead.

    Returns the loss meter.
    """
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)
    val_meter = AverageMeter()
    val_acc_meter = AverageMeter()
    # accuracy is only meaningful for classification criteria
    is_classification = isinstance(criterion, torch.nn.CrossEntropyLoss)
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            if is_classification:
                predicted = torch.argmax(y_pred, dim=1)
                correct = torch.sum((predicted == y)).item()
                total = predicted.size()[0]
                val_acc_meter.update(correct / total)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    if print_results:
        print("Val loss: {}, Val acc: {}".format(
            val_meter.avg,
            val_acc_meter.avg if is_classification else "Not applicable"))
    return val_meter
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, logger):
    """One DARTS-style search epoch: second-order (unrolled) architecture
    update followed by a weight update per batch.

    Returns (avg_base_loss, avg_base_top1, avg_base_top5).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        # print(111111111111111111111)
        # print(arch_inputs.size())
        # print(arch_targets.size())
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # update the architecture-weight (unrolled/second-order step)
        a_optimizer.zero_grad()
        arch_loss, arch_logits = backward_step_unrolled(
            network, criterion, base_inputs, base_targets, w_optimizer,
            arch_inputs, arch_targets)
        a_optimizer.step()
        # record (note: top-k here is (1, 2))
        arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 2))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
        # update the weights
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 2))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # log only on the final batch of the epoch
        if step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(xloader))
            # Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.avg:.3f} Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.avg:.3f} Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg
def train_bptt(
    num_epochs: int,
    model,
    criterion,
    w_optimizer,
    a_optimizer,
    dset_train,
    dset_val,
    batch_size: int,
    T: int,
    w_checkpoint_freq: int,
    grad_clip: float,
    w_lr: float,
    logging_freq: int,
    grad_inner_loop_order: int,
    grad_outer_loop_order: int,
    hvp: str,
    arch_train_data: str,
    normalize_a_lr: bool,
    log_grad_norm: bool,
    log_alphas: bool,
    w_warm_start: int,
    extra_weight_decay: float,
    train_arch: bool,
    device: str,
):
    """Train weights with truncated back-propagation-through-training (SoTL).

    Each outer batch of size ``batch_size * T`` is split into ``T`` inner
    mini-batches; model weights are checkpointed into a ``WeightBuffer`` after
    every inner step so that ``sotl_gradient`` can unroll the trajectory and
    compute the architecture gradient. Architecture parameters are updated
    once per outer batch (after ``w_warm_start`` epochs) when ``train_arch``.

    Args:
        num_epochs: number of training epochs.
        model: network exposing ``weight_params()``, ``arch_params()``,
            ``named_weight_params()`` and optionally ``alpha_weight_decay``.
        criterion: loss function.
        w_optimizer / a_optimizer: optimizers for weights / architecture.
        dset_train / dset_val: training and validation datasets.
        batch_size: inner mini-batch size.
        T: number of inner steps per outer batch (unroll length).
        w_checkpoint_freq: weight-buffer checkpoint frequency.
        grad_clip: TODO(review) — currently unused; gradient clipping below is
            hard-coded to 1. Confirm whether it should be passed through.
        w_lr: weight learning rate forwarded to ``sotl_gradient``.
        logging_freq: console-print frequency (in inner batches).
        grad_inner_loop_order / grad_outer_loop_order / hvp / normalize_a_lr:
            forwarded to ``sotl_gradient``.
        arch_train_data: "val" to compute the arch gradient on validation data.
        log_grad_norm / log_alphas: extra wandb logging toggles.
        w_warm_start: epochs to train weights only before arch updates begin.
        extra_weight_decay: if nonzero, adds a learnable L2 penalty scaled by
            ``model.alpha_weight_decay``.
        train_arch: whether to update architecture parameters at all.
        device: "cuda"/"cpu"; auto-detected when None.
    """
    train_loader = torch.utils.data.DataLoader(
        dset_train, batch_size=batch_size * T, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=batch_size)
    grad_compute_speed = AverageMeter()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = AverageMeter()
        true_batch_index = 0
        val_iter = iter(val_loader)
        for batch_idx, batch in enumerate(train_loader):
            # Split the outer batch into T inner mini-batches.
            xs, ys = torch.split(batch[0], batch_size), torch.split(
                batch[1], batch_size
            )
            weight_buffer = WeightBuffer(T=T, checkpoint_freq=w_checkpoint_freq)
            weight_buffer.add(model, 0)
            for intra_batch_idx, (x, y) in enumerate(zip(xs, ys), 1):
                x = x.to(device)
                y = y.to(device)
                # weight_buffer.add(model, intra_batch_idx) # TODO Should it be added here?
                y_pred = model(x)
                # Optional learnable weight-decay penalty on weight-named params.
                param_norm = 0
                if extra_weight_decay is not None and extra_weight_decay != 0:
                    for n, weight in model.named_weight_params():
                        if 'weight' in n:
                            param_norm = param_norm + torch.pow(weight.norm(2), 2)
                    param_norm = torch.multiply(model.alpha_weight_decay, param_norm)
                loss = criterion(y_pred, y) + param_norm
                epoch_loss.update(loss.item())
                # Manual grad assignment keeps the autograd graph available
                # for the later SoTL unrolled architecture gradient.
                grads = torch.autograd.grad(loss, model.weight_params())
                with torch.no_grad():
                    for g, w in zip(grads, model.weight_params()):
                        w.grad = g
                torch.nn.utils.clip_grad_norm_(model.weight_params(), 1)
                w_optimizer.step()
                w_optimizer.zero_grad()
                weight_buffer.add(model, intra_batch_idx)
                true_batch_index += 1
                wandb.log(
                    {
                        "Train loss": epoch_loss.avg,
                        "Epoch": epoch,
                        "Batch": true_batch_index,
                    }
                )
                if true_batch_index % logging_freq == 0:
                    print(
                        "Epoch: {}, Batch: {}, Loss: {}, Alphas: {}".format(
                            epoch,
                            true_batch_index,
                            epoch_loss.avg,
                            [x.data for x in model.arch_params()],
                        )
                    )
            if train_arch:
                val_xs = None
                val_ys = None
                if arch_train_data == "val":
                    # Pull the next validation batch, cycling the iterator when
                    # it is exhausted. (Was a bare `except:` — that also
                    # swallowed KeyboardInterrupt/CUDA errors.)
                    try:
                        val_batch = next(val_iter)
                    except StopIteration:
                        val_iter = iter(val_loader)
                        val_batch = next(val_iter)
                    val_xs, val_ys = torch.split(val_batch[0], batch_size), torch.split(
                        val_batch[1], batch_size
                    )
                if epoch >= w_warm_start:
                    start_time = time.time()
                    total_arch_gradient = sotl_gradient(
                        model=model,
                        criterion=criterion,
                        xs=xs,
                        ys=ys,
                        weight_buffer=weight_buffer,
                        w_lr=w_lr,
                        hvp=hvp,
                        grad_inner_loop_order=grad_inner_loop_order,
                        grad_outer_loop_order=grad_outer_loop_order,
                        T=T,
                        normalize_a_lr=normalize_a_lr,
                        weight_decay_term=None,
                        val_xs=val_xs,
                        val_ys=val_ys,
                    )
                    grad_compute_speed.update(time.time() - start_time)
                    if log_grad_norm:
                        norm = 0
                        for g in total_arch_gradient:
                            norm = norm + g.data.norm(2).item()
                        wandb.log({"Arch grad norm": norm})
                    if log_alphas:
                        # NOTE(review): both branches log under the same
                        # "Alpha" key, so one can overwrite the other — confirm
                        # whether distinct keys were intended.
                        if hasattr(model, "fc1") and hasattr(model.fc1, "degree"):
                            wandb.log({"Alpha": model.fc1.degree.item()})
                        if hasattr(model, "alpha_weight_decay"):
                            wandb.log({"Alpha": model.alpha_weight_decay.item()})
                    a_optimizer.zero_grad()
                    for g, w in zip(total_arch_gradient, model.arch_params()):
                        w.grad = g
                    torch.nn.utils.clip_grad_norm_(model.arch_params(), 1)
                    a_optimizer.step()

        val_results = valid_func(
            model=model,
            dset_val=dset_val,
            criterion=criterion,
            device=device,
            print_results=False,
        )
        print("Epoch: {}, Val Loss: {}".format(epoch, val_results.avg))
        wandb.log({"Val loss": val_results.avg, "Epoch": epoch})

    wandb.run.summary["Grad compute speed"] = grad_compute_speed.avg
    print(f"Grad compute speed: {grad_compute_speed.avg}s")
def main(args):
    """End-to-end driver for facial-landmark heatmap training.

    Builds the data augmentation pipelines, training/evaluation dataloaders,
    the network, optimizer/scheduler/criterion, optionally resumes from the
    last checkpoint, then runs the train/eval loop, checkpointing each epoch.

    Args:
        args: parsed argparse namespace (paths, crop/scale/rotate augmentation
            parameters, batch size, workers, eval lists, etc.).
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    prepare_seed(args.rand_seed)
    logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
    logger = Logger(args.save_path, logstr)
    logger.log('Main Function with logger : {:}'.format(logger))
    logger.log('Arguments : -------------------------------')
    for name, value in args._get_kwargs():
        logger.log('{:16} : {:}'.format(name, value))
    logger.log("Python version : {}".format(sys.version.replace('\n', ' ')))
    logger.log("Pillow version : {}".format(PIL.__version__))
    logger.log("PyTorch version : {}".format(torch.__version__))
    logger.log("cuDNN version : {}".format(torch.backends.cudnn.version()))

    # General Data Argumentation
    # mean_fill is the ImageNet mean expressed as 0-255 RGB, used to pad crops.
    mean_fill = tuple([int(x*255) for x in [0.485, 0.456, 0.406]])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
    train_transform = [transforms.PreCrop(args.pre_crop_expand)]
    train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
    train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
    #if args.arg_flip:
    #  train_transform += [transforms.AugHorizontalFlip()]
    if args.rotate_max:
        train_transform += [transforms.AugRotate(args.rotate_max)]
    train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
    train_transform += [transforms.ToTensor(), normalize]
    train_transform = transforms.Compose(train_transform)
    eval_transform = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)), transforms.ToTensor(), normalize])
    assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)

    # Model Configure Load
    model_config = load_configure(args.model_config, logger)
    # Heatmap sigma is rescaled to match the evaluation scale.
    args.sigma = args.sigma * args.scale_eval
    logger.log('Real Sigma : {:}'.format(args.sigma))

    # Training Dataset
    train_data = Dataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
    train_data.load_list(args.train_lists, args.num_pts, True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

    # Evaluation Dataloader
    # Each entry is (loader, is_video): video lists first, then image lists.
    eval_loaders = []
    if args.eval_vlists is not None:
        for eval_vlist in args.eval_vlists:
            eval_vdata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
            eval_vdata.load_list(eval_vlist, args.num_pts, True)
            eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
            eval_loaders.append((eval_vloader, True))
    if args.eval_ilists is not None:
        for eval_ilist in args.eval_ilists:
            eval_idata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
            eval_idata.load_list(eval_ilist, args.num_pts, True)
            eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)
            eval_loaders.append((eval_iloader, False))

    # Define network
    logger.log('configure : {:}'.format(model_config))
    # num_pts + 1: extra channel (presumably a background heatmap — confirm).
    net = obtain_model(model_config, args.num_pts + 1)
    assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
    logger.log("=> network :\n {}".format(net))
    logger.log('Training-data : {:}'.format(train_data))
    for i, eval_loader in enumerate(eval_loaders):
        eval_loader, is_video = eval_loader
        logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))
    logger.log('arguments : {:}'.format(args))
    opt_config = load_configure(args.opt_config, logger)
    if hasattr(net, 'specify_parameter'):
        # Let the model provide per-parameter LR/decay groups.
        net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
    else:
        net_param_dict = net.parameters()
    optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
    logger.log('criterion : {:}'.format(criterion))
    net, criterion = net.cuda(), criterion.cuda()
    net = torch.nn.DataParallel(net)

    # Resume from the last checkpoint when a last-info file exists.
    last_info = logger.last_info()
    if last_info.exists():
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch'] + 1
        checkpoint = torch.load(last_info['last_checkpoint'])
        assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done".format(logger.last_info(), checkpoint['epoch']))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0

    if args.eval_once:
        logger.log("=> only evaluate the model once")
        eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
        logger.close() ; return

    # Main Training and Evaluation Loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(start_epoch, opt_config.epochs):
        scheduler.step()
        need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
        epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
        LRs = scheduler.get_lr()
        logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))
        # train for one epoch
        train_loss, train_nme = train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
        # log the results
        logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))
        # remember best prec@1 and save checkpoint
        save_path = save_checkpoint({
            'epoch': epoch,
            'args' : deepcopy(args),
            'arch' : model_config.arch,
            'state_dict': net.state_dict(),
            'scheduler' : scheduler.state_dict(),
            'optimizer' : optimizer.state_dict(),
        }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)
        last_info = save_checkpoint({
            'epoch': epoch,
            'last_checkpoint': save_path,
        }, logger.last_info(), logger)
        eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    logger.close()
def search_func(
    xloader,
    network,
    criterion,
    scheduler,
    w_optimizer,
    a_optimizer,
    epoch_str,
    print_freq,
    logger,
):
    """One epoch of SETN-style search on a DataParallel-wrapped super-network.

    Per batch: (1) samples an architecture via ``dync_genotype`` and trains the
    shared weights in 'dynamic' mode; (2) switches to 'joint' mode and trains
    the architecture parameters on the arch split.

    Returns:
        (base_loss, base_top1, base_top5, arch_loss, arch_top1, arch_top5) averages.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights: train shared weights under a sampled architecture
        sampled_arch = network.module.dync_genotype(True)
        network.module.set_cal_mode("dynamic", sampled_arch)
        # network.module.set_cal_mode( 'urs' )
        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight in 'joint' (all-paths) mode
        network.module.set_cal_mode("joint")
        network.zero_grad()
        _, logits = network(arch_inputs)
        arch_loss = criterion(logits, arch_targets)
        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* "
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
            # print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
            # print (network.module.arch_parameters)
    return (
        base_losses.avg,
        base_top1.avg,
        base_top5.avg,
        arch_losses.avg,
        arch_top1.avg,
        arch_top5.avg,
    )
def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger):
    """Train the ENAS RNN controller with REINFORCE for one round.

    Samples architectures from ``network.controller``, scores each on one
    validation batch (top-1 accuracy as reward, plus an entropy bonus), and
    maximizes the expected reward against a moving baseline. Gradients are
    accumulated over ``controller_num_aggregate`` samples per optimizer step.

    Args:
        xloader: validation loader; cycled if exhausted.
        network: super-network exposing ``controller`` and ``set_cal_mode``.
        criterion: unused here (reward is accuracy-based) but kept for API symmetry.
        optimizer: optimizer over controller parameters.
        prev_baseline: baseline reward from the previous epoch, or None.
        epoch_str: epoch tag for log lines.
        print_freq: logging frequency (in controller steps).
        logger: object with a ``log(str)`` method.

    Returns:
        (loss_avg, val_acc_avg, baseline_avg, reward_avg).
    """
    # config. (containing some necessary arg)
    # baseline: The baseline score (i.e. average val_acc) from the previous epoch
    data_time, batch_time = AverageMeter(), AverageMeter()
    GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(),
        AverageMeter(), AverageMeter(), time.time())

    # ENAS hyper-parameters (hard-coded as in the original ENAS setup).
    controller_num_aggregate = 20
    controller_train_steps = 50
    controller_bl_dec = 0.99
    controller_entropy_weight = 0.0001

    # Shared weights stay frozen/eval; only the controller is trained.
    network.eval()
    network.controller.train()
    network.controller.zero_grad()
    loader_iter = iter(xloader)
    for step in range(controller_train_steps * controller_num_aggregate):
        # Cycle the loader when exhausted. (Was a bare `except:`, which also
        # swallowed KeyboardInterrupt and real data-loading errors.)
        try:
            inputs, targets = next(loader_iter)
        except StopIteration:
            loader_iter = iter(xloader)
            inputs, targets = next(loader_iter)
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        log_prob, entropy, sampled_arch = network.controller()
        with torch.no_grad():
            # Evaluate the sampled architecture on this batch; reward in [0, 1].
            network.set_cal_mode('dynamic', sampled_arch)
            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            val_top1 = val_top1.view(-1) / 100
        reward = val_top1 + controller_entropy_weight * entropy
        if prev_baseline is None:
            baseline = val_top1
        else:
            # Exponential moving baseline to reduce REINFORCE variance.
            baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward)

        loss = -1 * log_prob * (reward - baseline)

        # account
        RewardMeter.update(reward.item())
        BaselineMeter.update(baseline.item())
        ValAccMeter.update(val_top1.item() * 100)
        LossMeter.update(loss.item())
        EntropyMeter.update(entropy.item())

        # Average gradient over controller_num_aggregate samples
        loss = loss / controller_num_aggregate
        loss.backward(retain_graph=True)

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()
        if (step + 1) % controller_num_aggregate == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0)
            # clip_grad_norm_ returns a Tensor in modern PyTorch; store a float
            # so the meter's running average stays a plain number.
            GradnormMeter.update(float(grad_norm))
            optimizer.step()
            network.controller.zero_grad()

        if step % print_freq == 0:
            Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, controller_train_steps * controller_num_aggregate)
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(
                loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
            Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)

    return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger):
    """One epoch of NAS search dispatching on ``algo``.

    Supported algorithms: 'setn', 'gdas', 'darts*' (incl. 'darts-v2' unrolled),
    'random' (uniform sampling), and 'enas' (controller-sampled). For each
    batch: set the calculation mode, update the shared weights on the base
    split, then update the architecture parameters on the arch split (random/
    enas only evaluate, since their arch params are not gradient-trained).

    Returns:
        (base_loss, base_top1, base_top5, arch_loss, arch_top1, arch_top5) averages.

    Raises:
        ValueError: if ``algo`` is not one of the supported names.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # Update the weights: choose the forward mode for this algorithm.
        if algo == 'setn':
            sampled_arch = network.dync_genotype(True)
            network.set_cal_mode('dynamic', sampled_arch)
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo == 'enas':
            with torch.no_grad():
                network.controller.eval()
                _, _, sampled_arch = network.controller()
            network.set_cal_mode('dynamic', sampled_arch)
        else:
            raise ValueError('Invalid algo name : {:}'.format(algo))

        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        if algo == 'setn':
            network.set_cal_mode('joint')
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo != 'enas':
            raise ValueError('Invalid algo name : {:}'.format(algo))
        network.zero_grad()
        if algo == 'darts-v2':
            # Second-order DARTS: the unrolled step computes gradients itself.
            arch_loss, logits = backward_step_unrolled(
                network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
            a_optimizer.step()
        elif algo == 'random' or algo == 'enas':
            # These algos do not train arch params by gradient: evaluate only.
            with torch.no_grad():
                _, logits = network(arch_inputs)
                arch_loss = criterion(logits, arch_targets)
        else:
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            arch_loss.backward()
            a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
    """One epoch of FLOP-constrained architecture search.

    Alternates per batch between (1) a weight update on the base split and
    (2) an architecture update on the arch split whose loss is the
    classification loss plus a FLOP penalty weighted by ``flop_weight``.

    Args:
        extra_info: dict with 'epoch-str', 'FLOP-exp' (target MB), 'FLOP-weight'
            and 'FLOP-tolerant' entries.

    Returns:
        (base_loss_avg, arch_loss_avg, top1_avg, top5_avg).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str, flop_need, flop_weight, flop_tolerant = extra_info[
        'epoch-str'], extra_info['FLOP-exp'], extra_info[
        'FLOP-weight'], extra_info['FLOP-tolerant']

    network.train()
    logger.log(
        '[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(
            epoch_str, flop_need, flop_weight))
    end = time.time()
    # Put every sub-module into search mode so forward returns expected FLOPs.
    network.apply(change_key('search_mode', 'search'))
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        # network.apply( change_key('search_mode', 'basic') )
        # features, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop('genotype', None, None)
        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
        acls_loss = criterion(logits, arch_targets)
        # Total arch loss: classification + weighted FLOP penalty.
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()

        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                loss=base_losses, top1=top1, top5=top5)
            Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
            logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])
    logger.log(
        ' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
def basic_train(args, loader, net, criterion, optimizer, epoch_str, logger, opt_config):
    """Train the landmark-heatmap network for one epoch.

    Runs the forward/backward pass per batch, accumulates stage losses, and
    collects per-image predictions (mapped back to original image coordinates)
    into an ``Eval_Meta`` to compute the epoch NME.

    Returns:
        (losses.avg, nme) — average training loss and normalized mean error.
    """
    args = deepcopy(args)
    batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    visible_points, losses = AverageMeter(), AverageMeter()
    eval_meta = Eval_Meta()
    cpu = torch.device('cpu')
    # switch to train mode
    net.train()
    criterion.train()

    end = time.time()
    for i, (inputs, target, mask, points, image_index, nopoints, cropped_size) in enumerate(loader):
        # inputs : Batch, Channel, Height, Width
        target = target.cuda(non_blocking=True)

        image_index = image_index.numpy().squeeze(1).tolist()
        batch_size, num_pts = inputs.size(0), args.num_pts
        # Average count of annotated (visible) landmarks per image; the last
        # mask channel is excluded (presumably background — confirm).
        visible_point_num = float(np.sum(mask.numpy()[:,:-1,:,:])) / batch_size
        visible_points.update(visible_point_num, batch_size)
        nopoints = nopoints.numpy().squeeze(1).tolist()
        annotated_num = batch_size - sum(nopoints)

        # measure data loading time
        mask = mask.cuda(non_blocking=True)
        data_time.update(time.time() - end)

        # batch_heatmaps is a list for stage-predictions, each element should be [Batch, C, H, W]
        batch_heatmaps, batch_locs, batch_scos = net(inputs)
        forward_time.update(time.time() - end)

        loss, each_stage_loss_value = compute_stage_loss(criterion, target, batch_heatmaps, mask)
        if opt_config.lossnorm:
            # Normalize by the number of annotated images (factor 2 per config).
            loss, each_stage_loss_value = loss / annotated_num / 2, [x/annotated_num/2 for x in each_stage_loss_value]

        # measure accuracy and record loss
        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        eval_time.update(time.time() - end)

        np_batch_locs, np_batch_scos = batch_locs.detach().to(cpu).numpy(), batch_scos.detach().to(cpu).numpy()
        cropped_size = cropped_size.numpy()
        # evaluate the training data
        for ibatch, (imgidx, nopoint) in enumerate(zip(image_index, nopoints)):
            if nopoint == 1: continue
            # Drop the last (auxiliary) point channel from locations/scores.
            locations, scores = np_batch_locs[ibatch,:-1,:], np.expand_dims(np_batch_scos[ibatch,:-1], -1)
            xpoints = loader.dataset.labels[imgidx].get_points()
            assert cropped_size[ibatch,0] > 0 and cropped_size[ibatch,1] > 0, 'The ibatch={:}, imgidx={:} is not right.'.format(ibatch, imgidx, cropped_size[ibatch])
            # Map network-resolution coordinates back to the original image:
            # scale by crop/input ratio, then offset by the crop origin.
            scale_h, scale_w = cropped_size[ibatch,0] * 1. / inputs.size(-2) , cropped_size[ibatch,1] * 1. / inputs.size(-1)
            locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[ibatch,2], locations[:, 1] * scale_h + cropped_size[ibatch,3]
            assert xpoints.shape[1] == num_pts and locations.shape[0] == num_pts and scores.shape[0] == num_pts, 'The number of points is {} vs {} vs {} vs {}'.format(num_pts, xpoints.shape, locations.shape, scores.shape)
            # recover the original resolution
            prediction = np.concatenate((locations, scores), axis=1).transpose(1,0)
            image_path = loader.dataset.datas[imgidx]
            face_size = loader.dataset.face_sizes[imgidx]
            eval_meta.append(prediction, xpoints, image_path, face_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        last_time = convert_secs2time(batch_time.avg * (len(loader)-i-1), True)
        end = time.time()

        if i % args.print_freq == 0 or i+1 == len(loader):
            logger.log(' -->>[Train]: [{:}][{:03d}/{:03d}] '
                       'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
                       'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
                       'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
                       'Loss {loss.val:7.4f} ({loss.avg:7.4f}) '.format(
                           epoch_str, i, len(loader), batch_time=batch_time,
                           data_time=data_time, forward_time=forward_time, loss=losses)
                       + last_time + show_stage_loss(each_stage_loss_value) \
                       + ' In={:} Tar={:}'.format(list(inputs.size()), list(target.size())) \
                       + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg))
    nme, _, _ = eval_meta.compute_mse(logger)
    return losses.avg, nme
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
    """Run one epoch of knowledge-distillation training or validation.

    The student ``network`` is trained (or evaluated) against both the ground
    truth and the frozen ``teacher`` via ``loss_KD_fn``; an optional auxiliary
    head loss is added when ``config.auxiliary > 0``.

    Args:
        mode: 'train' (backprop + optimizer step) or 'valid' (forward only).

    Returns:
        (losses.avg, top1.avg, top5.avg) for the student.

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'valid'.
    """
    data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    Ttop1, Ttop5 = AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        # NOTE(review): valid mode still builds the autograd graph for the
        # student forward (no torch.no_grad() here) — confirm if intended.
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    teacher.eval()

    logger.log(
        '[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'
        .format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1,
                config.KD_alpha, config.KD_temperature))
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)
        if mode == 'train': optimizer.zero_grad()

        student_f, logits = network(inputs)
        if isinstance(logits, list):
            # Networks with an auxiliary head return [main_logits, aux_logits].
            assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        with torch.no_grad():
            teacher_f, teacher_logits = teacher(inputs)

        loss = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature)
        if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
            # NOTE(review): assumes logits_aux is not None when auxiliary > 0;
            # a network without an aux head would crash here — confirm callers.
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == 'train':
            loss.backward()
            optimizer.step()

        # record
        sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(sprec1.item(), inputs.size(0))
        top5.update(sprec5.item(), inputs.size(0))
        # teacher
        tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
        Ttop1.update(tprec1.item(), inputs.size(0))
        Ttop5.update(tprec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
            if scheduler is not None:
                Sstr += ' {:}'.format(scheduler.get_min_info())
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                loss=losses, top1=top1, top5=top5)
            Lstr += ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg)
            Istr = 'Size={:}'.format(list(inputs.size()))
            logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)

    logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(
        mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
    logger.log(
        ' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'
        .format(mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg))
    return losses.avg, top1.avg, top5.avg
def stm_main_heatmap(args, loader, net, criterion, optimizer, epoch_str,
                     logger, opt_config, stm_config, use_stm, mode):
    """Run one epoch of spatio-temporal-multiview (STM) heatmap training/eval.

    Each mini-batch is a concatenation of three sample groups, in order:
    ``args.i_batch_size`` single images, ``args.v_batch_size`` video clips and
    ``args.m_batch_size`` multiview sets (``iv_size`` = images + videos). The
    total loss combines (1) a supervised heatmap detection loss, (2) an
    optional temporal (SBR) consistency loss on the video part, and (3) an
    optional multiview triangulation/reprojection loss.

    Args:
      args      : experiment namespace (deep-copied; batch sizes, num_pts,
                  debug/save settings, print frequencies).
      loader    : data loader; its dataset also provides video_L (annotated
                  frame index), datas, NormDistances, labels, dataset_name.
      net       : model mapping (frames, flows, multiview tensors, flags) to
                  heatmaps, locations, scores, tracking and multiview outputs.
      criterion : stage-loss criterion (switched to train/eval with the net).
      optimizer : used only when ``mode == 'train'``.
      epoch_str : epoch tag for logging / debug directories.
      logger    : object with a ``log(str)`` method.
      opt_config: optimization config (not read in this function body).
      stm_config: STM config (max_views, thresholds, sbr/stm weights).
      use_stm   : 2-element flag sequence: [use temporal loss, use multiview loss].
      mode      : 'train' or 'test' (asserted).

    Returns:
      (average total loss, NME computed by ``eval_meta.compute_mse``).
    """
    assert mode == 'train' or mode == 'test', 'invalid mode : {:}'.format(mode)
    args = copy.deepcopy(args)
    batch_time, data_time, forward_time, eval_time = AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter()
    visible_points, DetLosses, TemporalLosses, MultiviewLosses, TotalLosses = AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    # fraction of points usable for the temporal (alk) / multiview (a3d) losses
    alk_points, a3d_points = AverageMeter(), AverageMeter()
    # index of the annotated frame within each video clip
    annotate_index = loader.dataset.video_L
    eval_meta = Eval_Meta()
    cpu = torch.device('cpu')
    if args.debug:
        save_dir = Path(
            args.save_path) / 'DEBUG' / ('{:}-'.format(mode) + epoch_str)
    else:
        save_dir = None
    # switch to train mode
    if mode == 'train':
        logger.log('STM-Main-REG : training : {:} .. STM = {:}'.format(
            stm_config, use_stm))
        print_freq = args.print_freq
        net.train()
        criterion.train()
    else:
        logger.log('STM-Main-REG : evaluation mode.')
        print_freq = args.print_freq_eval
        net.eval()
        criterion.eval()
    i_batch_size, v_batch_size, m_batch_size = args.i_batch_size, args.v_batch_size, args.m_batch_size
    iv_size = i_batch_size + v_batch_size
    end = time.time()
    for i, (frames, Fflows, Bflows, targets, masks, normpoints, transthetas,
            MV_Tensors, MV_Thetas, MV_Shapes, MV_KRT, torch_is_3D,
            torch_is_images, image_index, nopoints, shapes,
            MultiViewPaths) in enumerate(loader):
        # frames : IBatch+VBatch+MBatch, Frame, Channel, Height, Width
        # Fflows : IBatch+VBatch+MBatch, Frame-1, Height, Width, 2
        # Bflows : IBatch+VBatch+MBatch, Frame-1, Height, Width, 2
        # split the concatenated batch: [:iv_size] -> image+video samples,
        # [iv_size:] -> multiview samples
        MV_Mask = masks[iv_size:]
        frames, Fflows, Bflows, targets, masks, normpoints, transthetas = frames[:iv_size], Fflows[:iv_size], Bflows[:iv_size], targets[:iv_size], masks[:iv_size], normpoints[:iv_size], transthetas[:iv_size]
        nopoints, shapes, torch_is_images = nopoints[:iv_size], shapes[:iv_size], torch_is_images[:iv_size]
        MV_Tensors, MV_Thetas, MV_Shapes, MV_KRT, torch_is_3D = \
            MV_Tensors[iv_size:], MV_Thetas[iv_size:], MV_Shapes[iv_size:], MV_KRT[iv_size:], torch_is_3D[iv_size:]
        # sanity checks on the batch layout described above
        assert torch.sum(torch_is_images[:i_batch_size]).item(
        ) == i_batch_size, 'Image Check Fail : {:} vs. {:}'.format(
            torch_is_images[:i_batch_size], i_batch_size)
        assert v_batch_size == 0 or torch.sum(
            torch_is_images[i_batch_size:]).item(
            ) == 0, 'Video Check Fail : {:} vs. {:}'.format(
                torch_is_images[i_batch_size:], v_batch_size)
        assert torch_is_3D.sum().item(
        ) == m_batch_size, 'Multiview Check Fail : {:} vs. {:}'.format(
            torch_is_3D, m_batch_size)
        image_index = image_index.squeeze(1).tolist()
        (batch_size, frame_length, C, H, W), num_pts, num_views = frames.size(
        ), args.num_pts, stm_config.max_views
        # average number of visible landmarks per sample (last mask channel
        # excluded — presumably a background/aggregate channel; confirm)
        visible_point_num = float(np.sum(
            masks.numpy()[:, :-1, :, :])) / batch_size
        visible_points.update(visible_point_num, batch_size)
        normpoints = normpoints.permute(0, 2, 1)
        target_heats = targets.cuda(non_blocking=True)
        # NOTE(review): target_points / target_scores are computed but not
        # read anywhere below in this function.
        target_points = normpoints[:, :, :2].contiguous().cuda(
            non_blocking=True)
        target_scores = normpoints[:, :, 2:].contiguous().cuda(
            non_blocking=True)
        # zero-out masks of samples that carry no annotation (nopoints == 1);
        # note: must happen before nopoints is converted to a python list
        det_masks = (1 - nopoints).view(batch_size, 1, 1, 1) * masks
        have_det_loss = det_masks.sum().item() > 0
        det_masks = det_masks.cuda(non_blocking=True)
        nopoints = nopoints.squeeze(1).tolist()
        # measure data loading time
        data_time.update(time.time() - end)
        # batch_heatmaps is a list for stage-predictions, each element should be [Batch, Sequence, PTS, H/Down, W/Down]
        batch_heatmaps, batch_locs, batch_scos, batch_past2now, batch_future2now, batch_FBcheck, multiview_heatmaps, multiview_locs = net(
            frames, Fflows, Bflows, MV_Tensors, torch_is_images)
        # supervise only the annotated frame of each sequence
        annot_heatmaps = [x[:, annotate_index] for x in batch_heatmaps]
        forward_time.update(time.time() - end)
        # detection loss
        if have_det_loss:
            det_loss, each_stage_loss_value = compute_stage_loss(
                criterion, target_heats, annot_heatmaps, det_masks)
            DetLosses.update(det_loss.item(), batch_size)
            each_stage_loss_value = show_stage_loss(each_stage_loss_value)
        else:
            det_loss, each_stage_loss_value = 0, 'no-det-loss'
        # temporal (SBR) loss on the video part of the batch
        if use_stm[0]:
            video_batch_locs = batch_locs[i_batch_size:, :, :num_pts]
            video_past2now, video_future2now = batch_past2now[
                i_batch_size:, :, :num_pts], batch_future2now[
                    i_batch_size:, :, :num_pts]
            video_FBcheck = batch_FBcheck[i_batch_size:, :num_pts]
            video_mask = masks[i_batch_size:, :num_pts].contiguous().cuda(
                non_blocking=True)
            video_heatmaps = [
                x[i_batch_size:, :, :num_pts] for x in batch_heatmaps
            ]
            sbr_loss, available_nums, loss_string = calculate_temporal_loss(
                criterion, video_heatmaps, video_batch_locs, video_past2now,
                video_future2now, video_FBcheck, video_mask, stm_config)
            alk_points.update(
                float(available_nums) / v_batch_size, v_batch_size)
            # only trust the loss when enough points passed the FB-check
            if available_nums > stm_config.available_sbr_thresh:
                TemporalLosses.update(sbr_loss.item(), v_batch_size)
            else:
                sbr_loss, sbr_loss_string = 0, 'non-sbr-loss'
        else:
            sbr_loss, sbr_loss_string = 0, 'non-sbr-loss'
        # multiview loss: triangulate predicted 2D points across views, then
        # penalize the reprojection disagreement
        if use_stm[1]:
            MV_Mask_G = MV_Mask[:, :-1].view(
                m_batch_size, 1, -1, 1).contiguous().cuda(non_blocking=True)
            MV_Thetas_G = MV_Thetas.to(multiview_locs.device)
            MV_Shapes_G = MV_Shapes.to(multiview_locs.device).view(
                m_batch_size, num_views, 1, 2)
            MV_KRT_G = MV_KRT.to(multiview_locs.device)
            # homogeneous coordinates in the cropped (theta) frame
            mv_norm_locs_trs = torch.cat(
                (multiview_locs[:, :, :num_pts].permute(0, 1, 3, 2),
                 torch.ones(m_batch_size,
                            num_views,
                            1,
                            num_pts,
                            device=multiview_locs.device)),
                dim=2)
            # map back to the original (un-cropped) normalized frame
            mv_norm_locs_ori = torch.matmul(MV_Thetas_G[:, :, :2],
                                            mv_norm_locs_trs)
            mv_norm_locs_ori = mv_norm_locs_ori.permute(0, 1, 3, 2)
            mv_real_locs_ori = denormalize_L(mv_norm_locs_ori, MV_Shapes_G)
            # DLT triangulation to 3D, then reproject into every view
            mv_3D_locs_ori = TriangulateDLT_BatchCam(MV_KRT_G,
                                                     mv_real_locs_ori)
            mv_proj_locs_ori = ProjectKRT_Batch(
                MV_KRT_G, mv_3D_locs_ori.view(m_batch_size, 1, num_pts, 3))
            mv_pnorm_locs_ori = normalize_L(mv_proj_locs_ori, MV_Shapes_G)
            mv_pnorm_locs_trs = convert_theta(mv_pnorm_locs_ori, MV_Thetas_G)
            MV_locs = multiview_locs[:, :, :num_pts].contiguous()
            MV_heatmaps = [x[:, :, :num_pts] for x in multiview_heatmaps]
            if args.debug:
                # dump per-sample visualizations of real vs. reprojected points
                with torch.no_grad():
                    for ims in range(m_batch_size):
                        x_index = image_index[iv_size + ims]
                        x_paths = [
                            xlist[iv_size + ims] for xlist in MultiViewPaths
                        ]
                        x_mv_locs, p_mv_locs = mv_real_locs_ori[
                            ims], mv_proj_locs_ori[ims]
                        multiview_debug_save(save_dir,
                                             '{:}'.format(x_index), x_paths,
                                             x_mv_locs.cpu().numpy(),
                                             p_mv_locs.cpu().numpy())
                        y_mv_locs = denormalize_points_batch((H, W),
                                                             MV_locs[ims])
                        q_mv_locs = denormalize_points_batch(
                            (H, W), mv_pnorm_locs_trs[ims])
                        temp_tensors = MV_Tensors[ims]
                        temp_images = [
                            args.tensor2imageF(x) for x in temp_tensors
                        ]
                        temp_names = [Path(x).name for x in x_paths]
                        multiview_debug_save_v2(save_dir,
                                                '{:}'.format(x_index),
                                                temp_names, temp_images,
                                                y_mv_locs.cpu().numpy(),
                                                q_mv_locs.cpu().numpy())
            stm_loss, available_nums = calculate_multiview_loss(
                criterion, MV_heatmaps, MV_locs, mv_pnorm_locs_trs, MV_Mask_G,
                stm_config)
            a3d_points.update(
                float(available_nums) / m_batch_size, m_batch_size)
            if available_nums > stm_config.available_stm_thresh:
                MultiviewLosses.update(stm_loss.item(), m_batch_size)
            else:
                stm_loss = 0
        else:
            stm_loss = 0
        # measure accuracy and record loss
        if use_stm[0]:
            total_loss = det_loss + sbr_loss * stm_config.sbr_weights + stm_loss * stm_config.stm_weights
        else:
            total_loss = det_loss + stm_loss * stm_config.stm_weights
        # total_loss stays a plain number only when every component was 0,
        # i.e. no loss term was active for this batch
        if isinstance(total_loss, numbers.Number):
            warnings.warn(
                'The {:}-th iteration has no detection loss and no lk loss'.
                format(i))
        else:
            TotalLosses.update(total_loss.item(), batch_size)
            # compute gradient and do SGD step
            if mode == 'train':  # training mode
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
        eval_time.update(time.time() - end)
        with torch.no_grad():
            # keep only the annotated frame's predictions, on CPU, for NME eval
            batch_locs = batch_locs.detach().to(cpu)[:, annotate_index, :num_pts]
            batch_scos = batch_scos.detach().to(cpu)[:, annotate_index, :num_pts]
            # evaluate the training data
            for ibatch in range(iv_size):
                imgidx, nopoint = image_index[ibatch], nopoints[ibatch]
                if nopoint == 1:
                    continue  # skip samples without ground-truth annotation
                # undo the crop transform: homogeneous coords -> theta -> image
                norm_locs = torch.cat(
                    (batch_locs[ibatch].permute(1, 0), torch.ones(1, num_pts)),
                    dim=0)
                transtheta = transthetas[ibatch][:2, :]
                norm_locs = torch.mm(transtheta, norm_locs)
                real_locs = denormalize_points(shapes[ibatch].tolist(),
                                               norm_locs)
                real_locs = torch.cat(
                    (real_locs, batch_scos[ibatch].view(1, num_pts)), dim=0)
                image_path = loader.dataset.datas[imgidx][annotate_index]
                normDistce = loader.dataset.NormDistances[imgidx]
                xpoints = loader.dataset.labels[imgidx].get_points()
                eval_meta.append(real_locs.numpy(), xpoints.numpy(),
                                 image_path, normDistce)
                if save_dir:
                    # NOTE(review): `meanthetas` is not defined in this
                    # function (only `transthetas` is) — unless it is a
                    # module-level name, this line raises NameError whenever
                    # save_dir is set; likely should be transthetas[ibatch].
                    # TODO confirm.
                    pro_debug_save(save_dir,
                                   Path(image_path).name,
                                   frames[ibatch, annotate_index],
                                   targets[ibatch], normpoints[ibatch],
                                   meanthetas[ibatch],
                                   batch_heatmaps[-1][ibatch, annotate_index],
                                   args.tensor2imageF)
        # measure elapsed time
        batch_time.update(time.time() - end)
        last_time = convert_secs2time(batch_time.avg * (len(loader) - i - 1),
                                      True)
        end = time.time()
        if i % print_freq == 0 or i + 1 == len(loader):
            logger.log(' -->>[{:}]: [{:}][{:03d}/{:03d}] '
                       'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
                       'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
                       'F-time {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
                       'Det {dloss.val:7.4f} ({dloss.avg:7.4f}) '
                       'SBR {sloss.val:7.6f} ({sloss.avg:7.6f}) '
                       'STM {mloss.val:7.6f} ({mloss.avg:7.6f}) '
                       'Loss {loss.val:7.4f} ({loss.avg:7.4f}) '.format(
                           mode, epoch_str, i, len(loader),
                           batch_time=batch_time, data_time=data_time,
                           forward_time=forward_time, \
                           dloss=DetLosses, sloss=TemporalLosses, mloss=MultiviewLosses, loss=TotalLosses) + last_time + each_stage_loss_value \
                       + ' I={:}'.format(list(frames.size())) \
                       + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg) \
                       + ' Ava-PTS : {:.1f} ({:.1f})'.format(alk_points.val, alk_points.avg) \
                       + ' A3D-PTS : {:.1f} ({:.1f})'.format(a3d_points.val, a3d_points.avg)
                       )
            if args.debug:
                logger.log(' -->>Indexes : {:}'.format(image_index))
    nme, _, _ = eval_meta.compute_mse(loader.dataset.dataset_name, logger)
    return TotalLosses.avg, nme
def lk_train(args, loader, net, criterion, optimizer, epoch_str, logger,
             opt_config, lk_config, use_lk):
    """Run one epoch of landmark training with an optional LK-tracking loss.

    Combines a supervised heatmap detection loss on annotated samples with a
    Lucas-Kanade temporal-consistency loss (supervision-by-registration) on
    video samples when ``use_lk`` is true.

    Args:
      args      : experiment namespace (deep-copied; num_pts, print_freq, ...).
      loader    : data loader; its dataset provides center_idx (the index of
                  the annotated frame inside each input sequence).
      net       : model mapping input sequences to multi-stage heatmaps plus
                  per-frame locations/scores and forward/backward LK outputs.
      criterion : stage-loss criterion (switched to train mode with the net).
      optimizer : optimizer stepped once per batch that produced a tensor loss.
      epoch_str : epoch tag for logging.
      logger    : object with a ``log(str)`` method.
      opt_config: optimization config; ``lossnorm`` toggles normalizing the
                  detection loss by the number of annotated samples.
      lk_config : LK config; ``weight`` scales the LK loss.
      use_lk    : whether to compute and add the LK temporal loss.

    Returns:
      Average combined loss over the epoch.
    """
    args = deepcopy(args)
    batch_time, data_time, forward_time, eval_time = AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter()
    visible_points, detlosses, lklosses = AverageMeter(), AverageMeter(
    ), AverageMeter()
    alk_points, losses = AverageMeter(), AverageMeter()
    # NOTE(review): `cpu` is never used in this function body.
    cpu = torch.device('cpu')
    annotate_index = loader.dataset.center_idx
    # switch to train mode
    net.train()
    criterion.train()
    end = time.time()
    for i, (inputs, target, mask, points, image_index, nopoints, video_or_not,
            cropped_size) in enumerate(loader):
        # inputs : Batch, Sequence Channel, Height, Width
        target = target.cuda(non_blocking=True)
        image_index = image_index.numpy().squeeze(1).tolist()
        batch_size, sequence, num_pts = inputs.size(0), inputs.size(
            1), args.num_pts
        mask_np = mask.numpy().squeeze(-1).squeeze(-1)
        # average visible landmarks per sample (last mask channel excluded —
        # presumably a background/aggregate channel; confirm)
        visible_point_num = float(np.sum(
            mask.numpy()[:, :-1, :, :])) / batch_size
        visible_points.update(visible_point_num, batch_size)
        nopoints = nopoints.numpy().squeeze(1).tolist()
        video_or_not = video_or_not.numpy().squeeze(1).tolist()
        # samples with nopoints == 1 carry no annotation
        annotated_num = batch_size - sum(nopoints)
        # measure data loading time
        mask = mask.cuda(non_blocking=True)
        data_time.update(time.time() - end)
        # batch_heatmaps is a list for stage-predictions, each element should be [Batch, Sequence, PTS, H/Down, W/Down]
        batch_heatmaps, batch_locs, batch_scos, batch_next, batch_fback, batch_back = net(
            inputs)
        # supervise only the center (annotated) frame of each sequence
        annot_heatmaps = [x[:, annotate_index] for x in batch_heatmaps]
        forward_time.update(time.time() - end)
        if annotated_num > 0:
            # have the detection loss
            detloss, each_stage_loss_value = compute_stage_loss(
                criterion, target, annot_heatmaps, mask)
            if opt_config.lossnorm:
                # normalize by number of annotated samples (x2 — TODO confirm
                # why the extra factor of 2; it is applied consistently to the
                # total and the per-stage values)
                detloss, each_stage_loss_value = detloss / annotated_num / 2, [
                    x / annotated_num / 2 for x in each_stage_loss_value
                ]
            # measure accuracy and record loss
            detlosses.update(detloss.item(), batch_size)
            each_stage_loss_value = show_stage_loss(each_stage_loss_value)
        else:
            detloss, each_stage_loss_value = 0, 'no-det-loss'
        if use_lk:
            # `avaliable` [sic] = number of points that passed the LK checks
            lkloss, avaliable = lk_target_loss(batch_locs, batch_scos,
                                               batch_next, batch_fback,
                                               batch_back, lk_config,
                                               video_or_not, mask_np, nopoints)
            if lkloss is not None:
                lklosses.update(lkloss.item(), avaliable)
            else:
                lkloss = 0
            alk_points.update(float(avaliable) / batch_size, batch_size)
        else:
            lkloss = 0
        loss = detloss + lkloss * lk_config.weight
        # loss stays a plain number only when both terms were 0 (nothing to train on)
        if isinstance(loss, numbers.Number):
            warnings.warn(
                'The {:}-th iteration has no detection loss and no lk loss'.
                format(i))
        else:
            losses.update(loss.item(), batch_size)
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # NOTE(review): eval_time is accumulated but never reported below.
        eval_time.update(time.time() - end)
        # measure elapsed time
        batch_time.update(time.time() - end)
        last_time = convert_secs2time(batch_time.avg * (len(loader) - i - 1),
                                      True)
        end = time.time()
        if i % args.print_freq == 0 or i + 1 == len(loader):
            logger.log(' -->>[Train]: [{:}][{:03d}/{:03d}] '
                       'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
                       'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
                       'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
                       'Loss {loss.val:7.4f} ({loss.avg:7.4f}) [LK={lk.val:7.4f} ({lk.avg:7.4f})] '.format(
                           epoch_str, i, len(loader), batch_time=batch_time,
                           data_time=data_time, forward_time=forward_time,
                           loss=losses, lk=lklosses) + each_stage_loss_value + ' ' + last_time \
                       + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg) \
                       + ' Ava-PTS : {:.1f} ({:.1f})'.format(alk_points.val, alk_points.avg))
    return losses.avg