def valid_func(xloader, network, criterion):
    """Evaluate `network` over `xloader`; return (avg loss, avg top-1, avg top-5)."""
    data_time, batch_time = AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.eval()
    tick = time.time()
    with torch.no_grad():
        for step, (arch_inputs, arch_targets) in enumerate(xloader):
            arch_targets = arch_targets.cuda(non_blocking=True)
            # time spent waiting on the data loader
            data_time.update(time.time() - tick)
            # forward pass and loss
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            # bookkeeping
            arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
            n_samples = arch_inputs.size(0)
            arch_losses.update(arch_loss.item(), n_samples)
            arch_top1.update(arch_prec1.item(), n_samples)
            arch_top5.update(arch_prec5.item(), n_samples)
            # full-step wall time
            batch_time.update(time.time() - tick)
            tick = time.time()
    return arch_losses.avg, arch_top1.avg, arch_top5.avg
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    """Run one epoch in 'train' or 'valid' mode.

    Returns (avg loss, avg top-1, avg top-5, total batch time).
    Raises ValueError for any other mode.
    """
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for batch_idx, (inputs, targets) in enumerate(xloader):
        if mode == 'train':
            # per-iteration LR schedule update
            scheduler.update(None, 1.0 * batch_idx / len(xloader))
        targets = targets.cuda(device=device, non_blocking=True)
        if mode == 'train':
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward only while training
        if mode == 'train':
            loss.backward()
            optimizer.step()
        # loss / accuracy bookkeeping
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        n_samples = inputs.size(0)
        losses.update(loss.item(), n_samples)
        top1.update(prec1.item(), n_samples)
        top5.update(prec5.item(), n_samples)
        # wall-clock per batch
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
def test_contrastive(args, model, nearest_proto_model, device, test_loader_creator_l, logger):
    """Evaluate `model` features with a nearest-prototype classifier over all task loaders.

    Logs overall accuracy and, when `args.acc_per_task` is set, per-task accuracy.
    """
    model.eval()
    acc = AverageMeter()
    tasks_acc = [AverageMeter() for _ in range(len(test_loader_creator_l.data_loaders))]
    test_loaders_l = test_loader_creator_l.data_loaders
    with torch.no_grad():
        for task_idx, test_loader_l in enumerate(test_loaders_l):
            for batch_idx, (data, _, target) in enumerate(test_loader_l):
                data, target = data.to(device), target.to(device)
                cur_feats, _ = model(data)
                output = nearest_proto_model.predict(cur_feats)
                # batch accuracy from exact label matches
                it_acc = (output == target).sum().item() / data.shape[0]
                acc.update(it_acc, data.size(0))
                tasks_acc[task_idx].update(it_acc, data.size(0))
    if args.acc_per_task:
        # BUGFIX: corrected typo 'Tess' -> 'Test' in the log prefix
        tasks_acc_str = 'Test Acc per task: '
        for i, task_acc in enumerate(tasks_acc):
            tasks_acc_str += 'Task{:2d} Acc: {acc.avg:.3f}'.format((i + 1), acc=task_acc) + '\t'
        logger.info(tasks_acc_str)
    logger.info('Test Acc: {acc.avg:.3f}'.format(acc=acc))
def train_or_test_epoch(self, xloader, model, loss_fn, metric_fn, is_train, optimizer=None):
    """Run one train or evaluation epoch; return (avg loss, avg metric score).

    When `is_train` is true and an `optimizer` is supplied, parameters are
    updated with gradient-value clipping at 3.0.
    """
    if is_train:
        model.train()
    else:
        model.eval()
    score_meter, loss_meter = AverageMeter(), AverageMeter()
    for ibatch, (feats, labels) in enumerate(xloader):
        feats = feats.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        # forward the network
        preds = model(feats)
        loss = loss_fn(preds, labels)
        with torch.no_grad():
            # BUGFIX: honor the `metric_fn` argument — it was previously
            # ignored in favor of self.metric_fn, silently discarding any
            # caller-supplied metric.
            score = metric_fn(preds, labels)
        loss_meter.update(loss.item(), feats.size(0))
        score_meter.update(score.item(), feats.size(0))
        # optimize the network
        if is_train and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
            optimizer.step()
    return loss_meter.avg, score_meter.avg
def test(args, model, device, test_loader_creator, logger):
    """Evaluate `model` on every loader from `test_loader_creator`; log loss and accuracy."""
    model.eval()
    criterion = torch.nn.CrossEntropyLoss().to(device)
    with torch.no_grad():
        losses, acc = AverageMeter(), AverageMeter()
        for loader in test_loader_creator.data_loaders:
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                _, output = model(data)
                loss = criterion(output, target)
                # cast to fp32 before bookkeeping
                output, loss = output.float(), loss.float()
                it_acc = accuracy(output.data, target)[0]
                n_samples = data.size(0)
                losses.update(loss.item(), n_samples)
                acc.update(it_acc.item(), n_samples)
    logger.info('Test set: Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Acc {acc.avg:.3f}'.format(loss=losses, acc=acc))
def pure_evaluate(xloader, network, criterion=None):
    """Evaluate `network` and measure per-batch latency.

    Returns (avg loss, avg top-1, avg top-5, list of per-batch latencies in
    seconds). `criterion` defaults to a fresh CrossEntropyLoss.

    BUGFIX: the default was `criterion=torch.nn.CrossEntropyLoss()`, which
    instantiates one module at function-definition time and shares it across
    every call; the None-sentinel builds it per call instead.
    """
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            # only record latency for full-size batches (skips a ragged final batch)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        # drop the warm-up batch
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
def train_shared_cnn(
    xloader,
    shared_cnn,
    controller,
    criterion,
    scheduler,
    optimizer,
    epoch_str,
    print_freq,
    logger,
):
    """Train the shared CNN for one epoch, sampling one architecture per step
    from the (frozen) controller.

    Returns (avg loss, avg top-1, avg top-5).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        time.time(),
    )
    shared_cnn.train()
    controller.eval()  # the controller only samples here; it is not updated
    for step, (inputs, targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)
        with torch.no_grad():
            # sample an architecture without tracking gradients
            _, _, sampled_arch = controller()
        optimizer.zero_grad()
        shared_cnn.module.update_arch(sampled_arch)
        _, logits = shared_cnn(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
        optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1s.update(prec1.item(), inputs.size(0))
        top5s.update(prec5.item(), inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*Train-Shared-CNN* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=losses, top1=top1s, top5=top5s)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return losses.avg, top1s.avg, top5s.avg
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    """Evaluate the searched network (in 'search' mode) on `xloader`.

    Returns (avg loss, avg top-1, avg top-5).
    """
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    network.eval()
    # switch sub-modules into 'search' mode before evaluating
    network.apply(change_key("search_mode", "search"))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)
            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = ("**VALID** " + time_string() +
                        " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader)))
                Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                    batch_time=batch_time, data_time=data_time)
                Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                    loss=losses, top1=top1, top5=top5)
                Istr = "Size={:}".format(list(inputs.size()))
                logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
    logger.log(
        " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}"
        .format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            loss=losses.avg,
        ))
    return losses.avg, top1.avg, top5.avg
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler,
                     optimizer, epoch_str, print_freq, logger):
    """Train the shared CNN on `ne` controller-sampled architectures, one full
    pass over `xloader` per architecture.

    The controller is frozen and only used for sampling. Note the meters are
    reset after each sampled architecture's pass, so the returned values
    reflect the final meter state only.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()
    shared_cnn.train()
    controller.eval()
    ne = 10  # number of architectures to sample
    for ni in range(ne):
        with torch.no_grad():
            _, _, sampled_arch = controller()
        shared_cnn.module.update_arch(sampled_arch)
        print(sampled_arch)
        # arch_str = op_list2str(sampled_arch)
        for step, (inputs, targets) in enumerate(xloader):
            # print(step,inputs,targets)
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)
            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record; NOTE(review): topk=(1, 2) means the "top5" meters actually
            # track top-2 accuracy — confirm this is intended
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 2))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()
        # if step + 1 == len(xloader):
        # BUGFIX: the format string had one placeholder but two arguments and a
        # hard-coded total of 10; both ni and ne are now rendered.
        Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:03d}/{:03d}]'.format(ni, ne)
        Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
            batch_time=batch_time, data_time=data_time)
        Wstr = '[Loss {loss.avg:.3f} Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
            loss=losses, top1=top1s, top5=top5s)
        losses.reset()
        top1s.reset()
        top5s.reset()
        logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
    return losses.avg, top1s.avg, top5s.avg
def search(self):
    """Run the evolutionary search loop for `self.max_epochs` epochs.

    Each epoch keeps the top `self.parent_num` candidates as parents, then
    refills the population via mutation and crossover.
    Returns (best candidate per epoch, the performance dict, performance trace).
    """
    self.eva_time = AverageMeter()
    init_start = time.time()
    self.init_random()  # seed the initial random population
    self.logger.log('Initial_takes: %.2f' % (time.time() - init_start))
    epoch_start_time = time.time()
    epoch_time_meter = AverageMeter()
    bests_per_epoch = list()
    perform_trace = list()
    for i in range(self.max_epochs):
        self.performances = torch.Tensor(self.performances)
        # indices of the best-performing candidates, descending
        top_k = torch.argsort(self.performances, descending=True)[:self.parent_num]
        if self.best_perf is None or self.performances[top_k[0]] > self.best_perf:
            self.best_cand = self.candidates[top_k[0]]
            self.best_perf = self.performances[top_k[0]]
        bests_per_epoch.append(self.best_cand)
        perform_trace.append(self.performances)
        self.parents = []
        for idx in top_k:
            self.parents.append(self.candidates[idx])
        # reset for the next generation
        self.candidates, self.performances = list(), list()
        self.eva_time = AverageMeter()
        self.get_mutation(self.population_num // 2)
        self.get_crossover()
        self.logger.log(
            '*SEARCH* ' + time_string() +
            '||| Epoch: %2d finished, %3d models have been tested, best performance is %.2f'
            % (i, len(self.perform_dict.keys()), self.best_perf))
        self.logger.log(' - Best Cand: ' + str(self.best_cand))
        this_epoch_time = time.time() - epoch_start_time
        epoch_time_meter.update(this_epoch_time)
        epoch_start_time = time.time()
        self.logger.log('Time for Epoch %d : %.2fs' % (i, this_epoch_time))
        self.logger.log(' -- Evaluated %d models, with %.2f s in average' %
                        (self.eva_time.count, self.eva_time.avg))
    self.logger.log(
        '--------\nSearching Finished. Best Arch Found with Acc %.2f' %
        (self.best_perf))
    self.logger.log(str(self.best_cand))
    #torch.save(self.best_cand, self.save_dir+'/best_arch.pth')
    #torch.save(self.perform_dict, self.save_dir+'/perform_dict.pth')
    return bests_per_epoch, self.perform_dict, perform_trace
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer,
                algo, epoch_str, print_freq, logger):
    """One epoch of architecture search with alternating updates.

    Per step: a weight update on the base split, then an architecture update
    on the arch split. `algo` selects the architecture loss:
      - 'tunas'        : REINFORCE with an EMA baseline (global RL_BASELINE_EMA)
      - 'tas' / 'fbv2' : plain cross-entropy on the arch split
    Returns (base loss/top1/top5 avgs, arch loss/top1/top5 avgs).
    Raises ValueError for an unknown `algo`.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # Update the weights
        network.zero_grad()
        _, logits, _ = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # update the architecture-weight
        network.zero_grad()
        _, logits, log_probs = network(arch_inputs)
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
        if algo == 'tunas':
            with torch.no_grad():
                RL_BASELINE_EMA.update(arch_prec1.item())
                rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
            rl_log_prob = sum(log_probs)
            arch_loss = - rl_advantage * rl_log_prob
        elif algo == 'tas' or algo == 'fbv2':
            arch_loss = criterion(logits, arch_targets)
        else:
            # BUGFIX: corrected misspelled 'algorightm' in the error message
            raise ValueError('invalid algorithm name: {:}'.format(algo))
        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def test_archi_acc(self, arch):
    """Evaluate architecture `arch` with the shared model.

    If a train loader is available, batch-norm running statistics are first
    reset and re-estimated by forwarding the training data. Returns
    (avg top-1, avg top-5, per-image latency in seconds; latency is 0.0 when
    self.lambda_t <= 0).
    """
    if self.train_loader is not None:
        # re-calibrate BN statistics for this architecture
        self.model.apply(ResetRunningStats)
        self.model.train()
        for step, (data, target) in enumerate(self.train_loader):
            # print('train step: {} total: {}'.format(step,max_train_iters))
            # data, target = train_dataprovider.next()
            # print('get data',data.shape)
            #data = data.cuda()
            output = self.model.forward(data, arch)  # _with_architect
        del data, target, output
    base_top1, base_top5 = AverageMeter(), AverageMeter()
    self.model.eval()
    one_batch = None
    for step, (data, target) in enumerate(self.val_loader):
        # print('test step: {} total: {}'.format(step,max_test_iters))
        # BUGFIX: use `is None` — once `one_batch` holds a tensor,
        # `one_batch == None` is a tensor comparison, not an identity check.
        if one_batch is None:
            one_batch = data
        batchsize = data.shape[0]
        # print('get data',data.shape)
        target = target.cuda(non_blocking=True)
        #data, target = data.to(device), target.to(device)
        _, logits = self.model.forward(data, arch)  # _with_architect
        prec1, prec5 = obtain_accuracy(logits.data, target.data, topk=(1, 5))
        base_top1.update(prec1.item(), batchsize)
        base_top5.update(prec5.item(), batchsize)
        del data, target, logits, prec1, prec5
    if self.lambda_t > 0.0:
        # time single-image forwards on (up to) 50 samples of the first batch
        start_time = time.time()
        len_batch = min(len(one_batch), 50)
        for i in range(len_batch):
            _, _ = self.model.forward(one_batch[i:i + 1, :, :, :], arch)
        end_time = time.time()
        time_per = (end_time - start_time) / len_batch
    else:
        time_per = 0.0
    #print('top1: {:.2f} top5: {:.2f}'.format(base_top1.avg * 100, base_top5.avg * 100))
    return base_top1.avg, base_top5.avg, time_per
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
    """Train `arch` from scratch under `seed` and evaluate it on every valid
    loader after each epoch.

    Returns a dict with FLOPs/params, per-epoch train/valid metrics and times,
    and the final network state dict.
    """
    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
                                               'C': arch_config['channel'],
                                               'N': arch_config['num_cells'],
                                               'genotype': arch,
                                               'num_classes': config.class_num}, None))
    #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, config.xshape)
    logger.log('Network : {:}'.format(net.get_message()), False)
    logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
    logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
    train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
    train_times, valid_times = {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)
        train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        with torch.no_grad():
            for key, xloder in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder, network, criterion, None, None, 'valid')
                # validation metrics are keyed by '<loader-name>@<epoch>'
                valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
                valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
                valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
                valid_times['{:}@{:}'.format(key, epoch)] = valid_tm
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True))
        logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
    # bundle everything measured during this run
    info_seed = {'flop': flop,
                 'param': param,
                 'channel': arch_config['channel'],
                 'num_cells': arch_config['num_cells'],
                 'config': config._asdict(),
                 'total_epoch': total_epoch,
                 'train_losses': train_losses,
                 'train_acc1es': train_acc1es,
                 'train_acc5es': train_acc5es,
                 'train_times': train_times,
                 'valid_losses': valid_losses,
                 'valid_acc1es': valid_acc1es,
                 'valid_acc5es': valid_acc5es,
                 'valid_times': valid_times,
                 'net_state_dict': net.state_dict(),
                 'net_string': '{:}'.format(net),
                 'finish-train': True}
    return info_seed
def train_bptt(num_epochs: int, model, dset_train, batch_size: int, T: int,
               w_checkpoint_freq: int, grad_clip: float, w_lr: float,
               logging_freq: int, sotl_order: int, hvp: str):
    """Train weights and architecture with truncated BPTT over the training
    trajectory (SoTL-style).

    Each loaded chunk of T*batch_size samples is split into T mini-batches;
    weights are checkpointed into a WeightBuffer, T weight steps are taken
    (with create_graph so the trajectory stays differentiable), then one
    architecture step uses the gradient from `sotl_gradient`.

    NOTE(review): relies on module-level `criterion`, `w_optimizer`,
    `a_optimizer` and `wandb` — confirm they exist at module scope. The
    `grad_clip` parameter is unused; the clip value 1 is hard-coded.
    """
    model.train()
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size * T,
                                               shuffle=True)
    for epoch in range(num_epochs):
        epoch_loss = AverageMeter()
        true_batch_index = 0
        for batch_idx, batch in enumerate(train_loader):
            # split the mega-batch into T sub-batches
            xs, ys = torch.split(batch[0], batch_size), torch.split(
                batch[1], batch_size)
            weight_buffer = WeightBuffer(T=T, checkpoint_freq=w_checkpoint_freq)
            for intra_batch_idx, (x, y) in enumerate(zip(xs, ys)):
                weight_buffer.add(model, intra_batch_idx)
                y_pred = model(x)
                loss = criterion(y_pred, y)
                epoch_loss.update(loss.item())
                # create_graph keeps the weight step differentiable so the
                # architecture gradient can flow through the trajectory
                grads = torch.autograd.grad(loss,
                                            model.weight_params(),
                                            retain_graph=True,
                                            allow_unused=True,
                                            create_graph=True)
                w_optimizer.zero_grad()
                with torch.no_grad():
                    for g, w in zip(grads, model.weight_params()):
                        w.grad = g
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                w_optimizer.step()
                true_batch_index += 1
                if true_batch_index % logging_freq == 0:
                    print("Epoch: {}, Batch: {}, Loss: {}".format(
                        epoch, true_batch_index, epoch_loss.avg))
                    wandb.log({"Train loss": epoch_loss.avg})
            # architecture step over the unrolled trajectory
            total_arch_gradient = sotl_gradient(model, criterion, xs, ys,
                                                weight_buffer, w_lr=w_lr,
                                                hvp=hvp, order=sotl_order)
            a_optimizer.zero_grad()
            for g, w in zip(total_arch_gradient, model.arch_params()):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(model.arch_params(), 1)
            a_optimizer.step()
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str,
                print_freq, logger):
    """One epoch of random-single-path weight training.

    Each step samples a random genotype and updates only the network weights
    on the base split. Returns (avg loss, avg top-1, avg top-5).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # update the weights
        network.module.random_genotype(True)  # sample a random architecture for this step
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return base_losses.avg, base_top1.avg, base_top5.avg
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
    """One epoch of search alternating dynamic-genotype weight updates and
    'joint'-mode architecture updates.

    Returns (base loss/top1/top5 avgs, arch loss/top1/top5 avgs).
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # update the weights
        sampled_arch = network.module.dync_genotype(True)  # sample an architecture for this step
        network.module.set_cal_mode('dynamic', sampled_arch)
        #network.module.set_cal_mode( 'urs' )
        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # update the architecture-weight
        network.module.set_cal_mode('joint')
        network.zero_grad()
        _, logits = network(arch_inputs)
        arch_loss = criterion(logits, arch_targets)
        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    #print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
    #print (network.module.arch_parameters)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
         splits: List[int], seeds: List[int], nets: List[str],
         opt_config: Dict[Text, Any],
         to_evaluate_indexes: tuple, cover_mode: bool):
    """Evaluate `nets[index]` for every index in `to_evaluate_indexes` on all
    datasets under every seed, saving one result file per (arch, seed).

    When `cover_mode` is true, existing result files are removed and
    recomputed; otherwise those (arch, seed) pairs are skipped.
    """
    log_dir = save_dir / 'logs'
    log_dir.mkdir(parents=True, exist_ok=True)
    logger = Logger(str(log_dir), os.getpid(), False)
    logger.log('xargs : seeds = {:}'.format(seeds))
    logger.log('xargs : cover_mode = {:}'.format(cover_mode))
    logger.log('-' * 100)
    logger.log(
        'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
        + '({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
    for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
        logger.log(
            '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
    logger.log('--->>> optimization config : {:}'.format(opt_config))
    #to_evaluate_indexes = list(range(srange[0], srange[1] + 1))
    start_time, epoch_time = time.time(), AverageMeter()
    for i, index in enumerate(to_evaluate_indexes):
        channelstr = nets[index]
        logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(
            time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15))
        logger.log('{:} {:} {:}'.format('-' * 15, channelstr, '-' * 15))
        # test this arch on different datasets with different seeds
        has_continue = False
        for seed in seeds:
            to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
            if to_save_name.exists():
                if cover_mode:
                    logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
                    os.remove(str(to_save_name))
                else:
                    logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
                    has_continue = True
                    continue
            results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
            torch.save(results, to_save_name)
            logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}'.format(
                time_string(), i, len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
        # measure elapsed time (skipped archs would skew the per-arch average)
        if not has_continue:
            epoch_time.update(time.time() - start_time)
            start_time = time.time()
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True))
        logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True)))
        logger.log('{:}'.format('*' * 100))
        logger.log('{:} {:74s} {:}'.format(
            '*' * 10,
            '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, len(nets), need_time),
            '*' * 10))
        logger.log('{:}'.format('*' * 100))
    logger.close()
def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Callable = None,
):
    """Run one metric-driven epoch in 'train' or 'valid' mode (case-insensitive).

    Returns `metric.get_info()` after folding every batch into `metric`.
    Raises ValueError for an unknown mode. `logger_fn` is currently unused.
    """
    # BUGFIX: normalize the mode once. Previously the validity check used
    # mode.lower() but the in-loop branches compared the raw string, so a
    # value like "Train" passed validation yet silently skipped the
    # optimizer steps.
    mode = mode.lower()
    data_time, batch_time = AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        if mode == "train":
            optimizer.zero_grad()
        outputs = network(inputs)
        targets = targets.to(get_device(outputs))
        if mode == "train":
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        # record
        with torch.no_grad():
            results = metric(outputs, targets)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()
def valid_func(model, val_loader, criterion):
    """Compute and print the mean validation loss; return the loss meter."""
    model.eval()
    meter = AverageMeter()
    with torch.no_grad():
        for x, y in val_loader:
            preds = model(x)
            meter.update(criterion(preds, y).item())
    print("Val loss: {}".format(meter.avg))
    return meter
def train_shared_cnn(xloader, shared_cnn, criterion, scheduler, optimizer,
                     print_freq, logger, config, start_epoch):
    """Train the shared CNN from `start_epoch` up to config.epochs + config.warmup.

    Logs per-step and per-epoch statistics; returns nothing.
    """
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # BUGFIX: corrected typo 'Traing' -> 'Training' in the log message
        logger.log('\n[Training the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(scheduler.get_lr())))
        data_time, batch_time = AverageMeter(), AverageMeter()
        losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()
        shared_cnn.train()
        for step, (inputs, targets) in enumerate(xloader):
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)
            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()
            if step % print_freq == 0 or step + 1 == len(xloader):
                Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                    loss=losses, top1=top1s, top5=top5s)
                logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
        cnn_loss, cnn_top1, cnn_top5 = losses.avg, top1s.avg, top5s.avg
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        # measure elapsed epoch time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    return
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
    """One epoch of bi-level (DARTS second-order style) architecture search.

    Each batch alternates two updates:
      1. architecture parameters via `backward_step_unrolled` + `a_optimizer`,
         using the (arch_inputs, arch_targets) split;
      2. network weights via a plain supervised step + `w_optimizer`,
         using the (base_inputs, base_targets) split.

    Args:
        xloader: loader yielding (base_inputs, base_targets, arch_inputs, arch_targets).
        network: the search super-network (train mode is forced here).
        criterion: classification loss.
        scheduler: weight-LR scheduler; updated fractionally per step.
        w_optimizer: optimizer for network weights (also passed to the unrolled step).
        a_optimizer: optimizer for architecture parameters.
        epoch_str / print_freq / logger: logging helpers.

    Returns:
        (base_losses.avg, base_top1.avg, base_top5.avg) for the weight-update split.
        Note: the architecture-split statistics are logged but NOT returned.
    """
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)
        # update the architecture-weight (second-order/unrolled gradient)
        a_optimizer.zero_grad()
        arch_loss, arch_logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
        a_optimizer.step()
        # record architecture-split statistics
        # NOTE(review): topk=(1, 2) here while the meters/logs say 'Prec@5' —
        # confirm whether (1, 5) was intended; the sibling valid_func uses (1, 5).
        arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 2))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
        # update the weights on the base split
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record weight-split statistics
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 2))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # Only the final step of the epoch is logged.
        if step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Wstr = 'Base [Loss {loss.avg:.3f} Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.avg:.3f} Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
    """Train and evaluate a single architecture across several random seeds.

    For each seed, either loads an existing checkpoint from
    `save_dir/specifics/...` or runs `evaluate_all_datasets` and saves the
    result, then logs per-dataset summary statistics.

    Args:
        save_dir: root directory; results go under 'specifics/<tag>'.
        workers: number of CPU threads (passed to torch.set_num_threads).
        datasets / xpaths / splits: parallel per-dataset configuration lists.
        use_less: if True, use the reduced ('LESS') training schedule.
        seeds: iterable of random seeds to evaluate.
        model_str: either a key into CellArchitectures or a parseable cell string.
        arch_config: dict with at least 'channel' and 'num_cells'.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    # Deterministic mode for reproducibility across seeds (benchmark left off).
    torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True
    torch.set_num_threads( workers )
    save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells'])
    logger = Logger(str(save_dir), 0, False)
    # Resolve the architecture: named preset first, otherwise parse the string.
    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except:
            raise ValueError('Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
    assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
    logger.log('Start train-evaluate {:}'.format(arch.tostr()))
    logger.log('arch_config : {:}'.format(arch_config))
    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed))
        to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
        # Reuse a previous run if its checkpoint already exists.
        if to_save_name.exists():
            logger.log('Find the existing file {:}, directly load!'.format(to_save_name))
            checkpoint = torch.load(to_save_name)
        else:
            logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
            checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log('{:}'.format(checkpoint['info']))
        all_dataset_keys = checkpoint['all_dataset_keys']
        for dataset_key in all_dataset_keys:
            logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15))
            dataset_info = checkpoint[dataset_key]
            #logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param']))
            logger.log('config : {:}'.format(dataset_info['config']))
            logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
            last_epoch = dataset_info['total_epoch'] - 1
            train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
            valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
            logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch]))
        # measure elapsed time
        seed_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) )
        logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed, need_time))
    logger.close()
def eval_robust_heatmap(detector, xloader, print_freq, logger):
    """Evaluate a heatmap-based landmark detector under augmented views.

    Each sample provides `iters` augmented views of one image (with the
    per-view warp in `thetas`); predictions from every view are mapped back
    into the original image frame and averaged before scoring.

    Args:
        detector: model returning (features, heatmaps, locations, scores).
        xloader: loader whose dataset exposes NUM_PTS, labels, datas and
                 get_normalization_distance (per-sample normalization).
        print_freq: progress-log interval in batches.
        logger: project logger.

    Returns:
        (errors, valids, eval_meta) as produced by `calculate_robust` plus
        the accumulated Eval_Meta record.
    """
    batch_time, NUM_PTS = AverageMeter(), xloader.dataset.NUM_PTS
    Preds, GT_locs, Distances = [], [], []
    eval_meta, end = Eval_Meta(), time.time()
    with torch.no_grad():
        detector.eval()
        for i, (inputs, heatmaps, masks, norm_points, thetas, data_index, nopoints, xshapes) in enumerate(xloader):
            data_index = data_index.squeeze(1).tolist()
            # inputs is (batch, iters, C, H, W): `iters` augmented views per image.
            batch_size, iters, C, H, W = inputs.size()
            for ibatch in range(batch_size):
                xinputs, xpoints, xthetas = inputs[ibatch], norm_points[ibatch].permute(0, 2, 1).contiguous(), thetas[ibatch]
                batch_features, batch_heatmaps, batch_locs, batch_scos = detector(xinputs.cuda(non_blocking=True))
                # Drop the last (background) channel of the predicted locations.
                batch_locs = batch_locs.cpu()[:, :-1]
                all_locs = []
                for _iter in range(iters):
                    # Map view-space pixel coords -> normalized -> original frame
                    # via the per-view homogeneous transform xthetas[_iter].
                    _locs = normalize_points((H, W), batch_locs[_iter].permute(1, 0))
                    xlocs = torch.cat((_locs, torch.ones(1, NUM_PTS)), dim=0)
                    nlocs = torch.mm(xthetas[_iter, :2], xlocs)
                    rlocs = denormalize_points(xshapes[ibatch].tolist(), nlocs)
                    # Re-attach the per-point extra columns (e.g. visibility) from xpoints.
                    rlocs = torch.cat((rlocs.permute(1, 0), xpoints[_iter, :, 2:]), dim=1)
                    all_locs.append(rlocs.clone())
                GT_loc = xloader.dataset.labels[data_index[ibatch]].get_points()
                norm_distance = xloader.dataset.get_normalization_distance(data_index[ibatch])
                # save the results: the prediction is the mean over all views.
                eval_meta.append((sum(all_locs) / len(all_locs)).numpy().T, GT_loc.numpy(), xloader.dataset.datas[data_index[ibatch]], norm_distance)
                Distances.append(norm_distance)
                Preds.append(all_locs)
                GT_locs.append(GT_loc.permute(1, 0))
            # compute time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0 or i + 1 == len(xloader):
                last_time = convert_secs2time(batch_time.avg * (len(xloader) - i - 1), True)
                logger.log(' -->>[Robust HEATMAP-based Evaluation] [{:03d}/{:03d}] Time : {:}'.format(i, len(xloader), last_time))
    # evaluate the results
    errors, valids = calculate_robust(Preds, GT_locs, Distances, NUM_PTS)
    return errors, valids, eval_meta
def train_func(xargs, search_loader, valid_loader, network, operations, criterion, w_scheduler, w_optimizer, logger, drop_iter, total_epoch): logger.log('|=> Train, drop_iter={}, epochs={}'.format( drop_iter, total_epoch)) # start training start_time, search_time, epoch_time, start_epoch = time.time( ), AverageMeter(), AverageMeter(), 0 for epoch in range(start_epoch, total_epoch): w_scheduler.update(epoch, 0.0) need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch - epoch), True)) epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch) logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format( epoch_str, need_time, min(w_scheduler.get_lr()))) search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \ = search_func(search_loader, network, operations, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger) search_time.update(time.time() - start_time) logger.log( '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s' .format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
def valid_func(model, dset_val, criterion, print_results=True):
    """Compute the validation loss of `model` over `dset_val`.

    Args:
        model: the network to evaluate (switched to eval mode here).
        dset_val: validation dataset, batched internally with batch_size=32.
        criterion: loss function applied to (predictions, targets).
        print_results: if True, print the average validation loss.

    Returns:
        The loss AverageMeter; `.avg` holds the per-sample mean loss.
    """
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)
    val_meter = AverageMeter()
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            y_pred = model(x)
            val_loss = criterion(y_pred, y)
            # Fix: weight the update by the batch size so .avg is a true
            # per-sample mean even when the last batch is smaller than 32.
            val_meter.update(val_loss.item(), x.size(0))
    if print_results:
        print("Val loss: {}".format(val_meter.avg))
    return val_meter
def valid_func(model, dset_val, criterion, device = 'cuda' if torch.cuda.is_available() else 'cpu', print_results=True):
    """Compute validation loss (and accuracy for classification) on `dset_val`.

    Args:
        model: network to evaluate (switched to eval mode here).
        dset_val: validation dataset, batched internally with batch_size=32.
        criterion: loss function; accuracy is only tracked for CrossEntropyLoss.
        device: device inputs/targets are moved to before the forward pass.
        print_results: if True, print average loss and accuracy.

    Returns:
        The loss AverageMeter; `.avg` holds the per-sample mean loss.
    """
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)
    val_meter = AverageMeter()
    val_acc_meter = AverageMeter()
    # Accuracy only makes sense for classification; decide once up front.
    track_acc = isinstance(criterion, torch.nn.CrossEntropyLoss)
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            if track_acc:
                predicted = torch.argmax(y_pred, dim=1)
                correct = torch.sum((predicted == y)).item()
                total = predicted.size()[0]
                # Fix: weight by batch size so .avg is a per-sample accuracy.
                val_acc_meter.update(correct / total, total)
            val_loss = criterion(y_pred, y)
            # Fix: weight by batch size so .avg is a per-sample mean loss.
            val_meter.update(val_loss.item(), x.size(0))
    if print_results:
        # Fix: applicability is decided by the criterion type, not by
        # `avg > 0` (which mislabeled a 0%-accuracy run as "Not applicable").
        print("Val loss: {}, Val acc: {}".format(
            val_meter.avg,
            val_acc_meter.avg if track_acc else "Not applicable"))
    return val_meter
def train_normal(num_epochs, model, dset_train, batch_size, grad_clip, logging_freq, optim="sgd", **kwargs):
    """Train `model` on `dset_train` with either SGD or a Newton-style update.

    NOTE(review): `w_optimizer`, `criterion`, `wandb` and `hessian` are not
    parameters — they must exist in the enclosing/module scope. Confirm they
    are defined wherever this function is used.
    NOTE(review): `grad_clip` is accepted but unused; the SGD branch clips at
    the hard-coded norm 1 — confirm which was intended.

    Args:
        num_epochs: number of passes over the training set.
        model: network exposing weight_params() and an fc1.alphas tensor.
        dset_train: training dataset (shuffled DataLoader is built here).
        batch_size: mini-batch size.
        grad_clip: (currently unused, see note above).
        logging_freq: print every `logging_freq` batches.
        optim: "sgd" (clip + optimizer step) or "newton" (inverse-Hessian step).
        **kwargs: unused pass-through.
    """
    train_loader = torch.utils.data.DataLoader(dset_train, batch_size=batch_size, shuffle=True)
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = AverageMeter()
        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            w_optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            # retain_graph=True keeps the graph alive for the Hessian
            # computation in the "newton" branch below.
            loss.backward(retain_graph=True)
            epoch_loss.update(loss.item())
            if optim == "newton":
                # Newton step on the first weight tensor: grad @ H^-1.
                linear_weight = list(model.weight_params())[0]
                hessian_newton = torch.inverse(
                    hessian(loss * 1, linear_weight, linear_weight).reshape(
                        linear_weight.size()[1], linear_weight.size()[1]))
                with torch.no_grad():
                    for w in model.weight_params():
                        # subtract_ is in-place; the rebinding of `w` is a no-op.
                        w = w.subtract_(torch.matmul(w.grad, hessian_newton))
            elif optim == "sgd":
                torch.nn.utils.clip_grad_norm_(model.weight_params(), 1)
                w_optimizer.step()
            else:
                raise NotImplementedError
            wandb.log({
                "Train loss": epoch_loss.avg,
                "Epoch": epoch,
                "Batch": batch_idx
            })
            if batch_idx % logging_freq == 0:
                print("Epoch: {}, Batch: {}, Loss: {}, Alphas: {}".format(
                    epoch, batch_idx, epoch_loss.avg, model.fc1.alphas.data))
def check_files(save_dir, meta_file, basestr):
    """Audit which architecture checkpoints exist under `save_dir`.

    Scans every sub-directory matching '*-*-<basestr>' for files named
    'arch-<index>-seed-<seed>.pth', reports how many architectures have been
    evaluated and with how many seeds each, and prints per-directory progress.

    Args:
        save_dir: Path to the root directory holding checkpoint sub-dirs.
        meta_file: torch-saved dict with 'archs', 'total' and 'max_node'.
        basestr: suffix identifying the relevant sub-directories.
    """
    meta_infos = torch.load(meta_file, map_location='cpu')
    meta_archs = meta_infos['archs']
    meta_num_archs = meta_infos['total']
    meta_max_node = meta_infos['max_node']
    assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))
    sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
    print('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
    subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
    num_seeds = defaultdict(lambda: 0)
    for index, sub_dir in enumerate(sub_model_dirs):
        xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
        arch_indexes = set()
        for checkpoint in xcheckpoints:
            # Checkpoint names follow 'arch-<index>-seed-<seed>.pth'.
            temp_names = checkpoint.name.split('-')
            assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
            arch_indexes.add( temp_names[1] )
        subdir2archs[sub_dir] = sorted(list(arch_indexes))
        num_evaluated_arch += len(arch_indexes)
        # count number of seeds for each architecture
        for arch_index in arch_indexes:
            num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
    print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
    for key in sorted( list( num_seeds.keys() ) ):
        print('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))
    dir2ckps, dir2ckp_exists = dict(), dict()
    start_time, epoch_time = time.time(), AverageMeter()
    for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
        # Expected seeds per architecture.
        seeds = [777, 888, 999]
        numrs = defaultdict(lambda: 0)
        all_checkpoints, all_ckp_exists = [], []
        for arch_index in arch_indexes:
            checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
            ckp_exists = [(sub_dir/x).exists() for x in checkpoints]
            # Rebind arch_index (str -> int) only for the range check below.
            arch_index = int(arch_index)
            assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
            all_checkpoints += checkpoints
            all_ckp_exists += ckp_exists
            # Histogram: how many of the expected seeds exist for this arch.
            numrs[sum(ckp_exists)] += 1
        dir2ckps[ str(sub_dir) ] = all_checkpoints
        dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
        # measure time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] )
        print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))
def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
    """Evaluate a contiguous range of architectures (or a single one) across seeds.

    For each architecture index in `srange` (or just `arch_index` if given),
    trains/evaluates it on every dataset for every seed, saving one .pth per
    (index, seed). Existing files are skipped unless `cover_mode` is set,
    in which case they are removed and re-evaluated.

    Args:
        save_dir: root output directory; a range-specific sub-dir is created.
        workers: CPU thread count.
        datasets / xpaths / splits: parallel per-dataset configuration lists.
        use_less: use the reduced ('LESS') training schedule and dir suffix.
        srange: (start, end) inclusive architecture-index range.
        arch_index: -1 to evaluate the whole range, otherwise a single index.
        seeds: random seeds to evaluate per architecture.
        cover_mode: overwrite existing result files instead of skipping.
        meta_info: dict with 'archs' (list of arch strings) and 'total'.
        arch_config: dict with at least 'channel' and 'num_cells'.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    #torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads( workers )
    assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
    if use_less:
        sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
    else:
        sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
    logger = Logger(str(sub_dir), 0, False)
    all_archs = meta_info['archs']
    assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total'])
    assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1])
    if arch_index == -1:
        to_evaluate_indexes = list(range(srange[0], srange[1]+1))
    else:
        to_evaluate_indexes = [arch_index]
    logger.log('xargs : seeds = {:}'.format(seeds))
    logger.log('xargs : arch_index = {:}'.format(arch_index))
    logger.log('xargs : cover_mode = {:}'.format(cover_mode))
    logger.log('-'*100)
    logger.log('Start evaluating range =: {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode))
    for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
        logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
    logger.log('--->>> architecture config : {:}'.format(arch_config))
    start_time, epoch_time = time.time(), AverageMeter()
    for i, index in enumerate(to_evaluate_indexes):
        arch = all_archs[index]
        logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15))
        #logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
        logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15))
        # test this arch on different datasets with different seeds
        has_continue = False
        for seed in seeds:
            to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
            if to_save_name.exists():
                if cover_mode:
                    logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
                    os.remove(str(to_save_name))
                else:
                    logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
                    has_continue = True
                    continue
            results = evaluate_all_datasets(CellStructure.str2structure(arch), \
                                            datasets, xpaths, splits, use_less, seed, \
                                            arch_config, workers, logger)
            torch.save(results, to_save_name)
            logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
        # measure elapsed time; skip timing when any seed was skipped, since
        # that iteration did not do a full evaluation.
        if not has_continue: epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) )
        logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) ))
        logger.log('{:}'.format('*'*100))
        logger.log('{:} {:74s} {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10))
        logger.log('{:}'.format('*'*100))
    logger.close()
def main(xargs):
    """Run a full GDAS architecture search: setup, (optional) resume, epoch loop.

    Builds the dataset/loaders and the GDAS search model from `xargs`, resumes
    from the last checkpoint if one exists, then alternates weight and
    architecture updates via `search_func` for `config.epochs + config.warmup`
    epochs, checkpointing every epoch and tracking the best genotype.

    NOTE(review): `prepare_logger(args)` and `deepcopy(args)` below reference a
    module-level `args` rather than the `xargs` parameter — confirm this is the
    intended global and not a typo for `xargs`.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    # Deterministic cudnn for reproducible search runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(args)
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    #config_path = 'configs/nas-benchmark/algos/GDAS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    search_space = get_search_spaces('cell', xargs.search_space_name)
    # Build the model config either from CLI arguments or from a config file.
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))
    # Separate optimizers: weights follow the scheduled optimizer, the
    # architecture parameters get their own Adam.
    w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space))
    # Optional NAS benchmark API for querying ground-truth performance.
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))
    last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }
    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        # Gumbel temperature annealed linearly from tau_max to tau_min.
        search_model.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr())))
        search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
            = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
        logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False
        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
        # save checkpoint (full state) plus a small last-info pointer file
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(args),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1], '200')))
    logger.close()