def get_cifar_tuned_model(load_weights=True, weights_path='weights/cifar_tuned.pt'):
    """Build the tuned CIFAR model and optionally load pretrained weights.

    Args:
        load_weights: when True, load a state dict from ``weights_path``.
        weights_path: checkpoint location; the default preserves the original
            hard-coded path, so existing callers are unaffected.

    Returns:
        The constructed ``NetworkCIFAR`` instance (weights loaded if requested).
    """
    # NOTE(review): positional args appear to be (init_channels, num_classes,
    # layers, auxiliary, drop_path_prob?, genotype) — confirm against the
    # NetworkCIFAR constructor, since other call sites omit the 0.4 argument.
    network = NetworkCIFAR(40, CIFAR_CLASSES, 20, True, 0.4, CIFAR_TUNED)
    if load_weights:
        # Map to CPU so the checkpoint loads even on machines without a GPU;
        # the caller can move the model to a device afterwards.
        device = torch.device('cpu')
        state_dict = torch.load(weights_path, map_location=device)
        network.load_state_dict(state_dict)
    return network
def random_generate(self):
    """Randomly generate candidate architectures.

    Repeatedly samples random architecture parameters until, for each of
    ``num_arch`` slots, a genotype is found whose normal cell contains exactly
    ``num_identity`` skip connections and whose profiled FLOPs meet the
    configured threshold. Populates and returns ``self.random_arch_list`` as a
    list of ``('arch_<i>', genotype)`` tuples.
    """
    num_skip_connect = SearchControllerConf['random_search']['num_identity']
    num_arch = SearchControllerConf['random_search']['num_arch']
    flops_threshold = SearchControllerConf['random_search']['flops_threshold']
    # k is the number of edges in the cell DAG: 2 + 3 + ... + (steps + 1).
    # k = 2 + 3 + 4 + 5 = 14
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)
    self.random_arch_list = []
    for ai in range(num_arch):
        # Reseed per candidate; torch.randn below still advances the RNG, so
        # each while-iteration produces different alphas.
        seed = random.randint(0, 1000)
        torch.manual_seed(seed)
        while True:
            # Small random logits over ops for every edge (no gradients —
            # pure random search, not differentiable search).
            self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=False)
            self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=False)
            arch = self.genotype()
            # Count skip connections in the normal cell; only accept an arch
            # whose count equals num_skip_connect exactly.
            op_names, indices = zip(*arch.normal)
            cnt = 0
            for name, index in zip(op_names, indices):
                if name == 'skip_connect':
                    cnt += 1
            if cnt == num_skip_connect:
                # FLOPs gate: build a full-size evaluation network and profile
                # it on a single CIFAR-shaped input.
                model = NetworkCIFAR(36, 10, 20, True, arch, False)
                flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),), verbose=False)
                # flops is in raw ops; threshold is expressed in MFLOPs.
                if flops / 1e6 >= flops_threshold:
                    self.random_arch_list += [('arch_' + str(ai), arch)]
                    break
                else:
                    # Below the FLOPs threshold — resample.
                    continue
    return self.random_arch_list
def process_logs(args) -> DataFrame:
    """Parse evaluation/search training logs into a summary DataFrame.

    For each evaluation log in ``args.train`` this extracts, in order:
    model name and regularization weight, parameter size, six train/valid/test
    metrics; then search-stage losses/accuracy matched from ``args.search``
    logs by genotype; then profiling results (params, FLOPs, GPU/CPU latency)
    from re-instantiating the network. Partial failures are printed and the
    row is padded with ``None`` so the column count stays aligned.

    The resulting DataFrame is indexed by model name, sorted by weight
    (descending), printed, and written to ``args.output`` as CSV.

    Args:
        args: namespace with open file handles ``train`` (list), ``search``
            (list), and an ``output`` CSV path.

    Returns:
        The assembled ``pandas.DataFrame``.
    """
    data = []
    for log in args.train:
        row = []
        try:
            # --- evaluation stage metrics ---
            lines = str(log.readlines())
            match = re.search(r"arch='(?P<name>.*?)'", lines)
            name = match.group("name")
            row.append(name)
            # Weight is encoded in the arch name, e.g. l2_loss_2e01 -> 2e-01.
            weight_value = float(name.split("_")[2].replace("e", "e-"))
            row.append(weight_value)
            match = re.search(r"param size.*?(?P<value>\d*\.\d+)MB", lines)
            param_size = float(match.group("value"))
            row.append(param_size)
            for metric in [
                TRAIN_LOSS, TRAIN_ACC, VALID_LOSS, VALID_ACC, TEST_LOSS, TEST_ACC
            ]:
                # Take the last reported value of each metric in the log.
                value = float(
                    re.findall(rf'{metric}(?:uracy)? (?P<value>\d*\.\d+)', lines)[-1])
                row.append(value)
        except Exception as e:
            print(f"Error '{e}' while processing file {log.name}")
        # Pad so the search-stage columns line up even on partial failure.
        while len(row) < 9:
            row.append(None)
        try:
            # --- search stage metrics (matched by genotype string) ---
            genotype = genotypes.__dict__[name]
            genotype_str = str(genotype)
            match = False
            for s_log in args.search:
                s_lines = str(s_log.readlines())
                s_log.seek(0, 0)  # rewind so the handle can be reused for other rows
                # ((?!\\n).)* = anything except new line escaped
                match = re.search(
                    r"stats = (?P<stats>{((?!\\n).)*" + re.escape(genotype_str) +
                    r".*?})\\n\",", s_lines)
                if match:
                    # NOTE(security): eval of log text — acceptable only for
                    # trusted, locally produced logs.
                    stats = eval(match.group("stats"))
                    # A weight of -1 marks the unused regularizer's bucket.
                    # L2 loss case
                    if list(stats.get(L1_LOSS).keys())[0][0] == -1:
                        LOSS = L2_LOSS
                    # L1 loss case
                    elif list(stats.get(L2_LOSS).keys())[0][0] == -1:
                        LOSS = L1_LOSS
                    else:
                        raise Exception("L1 and L2 loss have w = -1")
                    values = list(stats.get(LOSS).values())[0]
                    search_criterion_loss = values[CRITERION_LOSS]
                    search_reg_loss = values[REG_LOSS]
                    row.append(search_criterion_loss)
                    row.append(search_reg_loss)
                    search_acc = values[VALID_ACC]
                    row.append(search_acc)
                    break
            if not match:
                raise Exception(f"Didn't find {name} on eval logs")
        except Exception as e:
            print(f"Error '{e}' while processing file {log.name}")
        while len(row) < 12:
            row.append(None)
        try:
            # --- model profiling ---
            genotype = genotypes.__dict__[name]
            match = re.search(r"init_channels=(?P<value>\d+)", lines)
            init_channels = int(match.group("value"))
            match = re.search(r"layers=(?P<value>\d+)", lines)
            layers = int(match.group("value"))
            match = re.search(r"drop_path_prob=(?P<value>\d+\.\d+)", lines)
            drop_path_prob = float(match.group("value"))
            match = re.search(r"auxiliary=(?P<value>\w+)", lines)
            # BUGFIX: bool("False") is True — any non-empty capture was truthy.
            # Compare against the literal repr instead.
            auxiliary = match.group("value") == 'True'
            model = NetworkCIFAR(init_channels, 10, layers, auxiliary, genotype)
            model.cuda()
            model.drop_path_prob = drop_path_prob
            parameters, net_flops, total_time_gpu, total_time_cpu = model_profiling(
                model, name)
            row.append(parameters)
            row.append(net_flops)
            row.append(total_time_gpu)
            row.append(total_time_cpu)
        except Exception as e:
            print(f"Error '{e}' while processing file {log.name}")
        if len(row) > 0:
            data.append(row)
    df = pd.DataFrame(data, columns=[
        MODEL_NAME, WEIGHT, PARAMETERS_DARTS, TRAIN_LOSS, TRAIN_ACC,
        VALID_LOSS, VALID_ACC, TEST_LOSS, TEST_ACC, SEARCH_CRIT_LOSS,
        SEARCH_REG_LOSS, SEARCH_ACC, PARAMETERS_OFA, FLOPS, LATENCY_GPU,
        LATENCY_CPU
    ])
    df.set_index(keys=MODEL_NAME, inplace=True)
    df.sort_values(by=WEIGHT, inplace=True, ascending=False)
    pd.set_option("display.max_rows", None, "display.max_columns", None,
                  "display.width", None)
    print(df)
    df.to_csv(args.output)
    return df
def train_arch(stage, step, valid_queue, model, optimizer_a, cur_switches_normal, cur_switches_reduce ):
    """One REINFORCE update step for the architecture parameters.

    Samples ``rl_batch_size`` sub-networks from the supernet, scores each on a
    validation batch, accumulates policy-gradient estimates weighted by
    (reward - moving baseline), and applies them via ``optimizer_a``.
    Mutates global ``best_prec1`` / ``best_normal_indices`` /
    ``best_reduce_indices`` and writes TensorBoard scalars every
    ``args.report_freq`` steps.
    """
    global best_prec1
    global best_normal_indices
    global best_reduce_indices
    # for step in range(100):
    # Pull one validation batch; the bare except re-creates the iterator when
    # it is exhausted (or, on the first call, not yet defined).
    try:
        input_search, target_search = next(valid_queue_iter)
    except:
        valid_queue_iter = iter(valid_queue)
        input_search, target_search = next(valid_queue_iter)
    input_search = input_search.cuda()
    target_search = target_search.cuda(non_blocking=True)
    # Per-sample buffers for the REINFORCE batch.
    normal_grad_buffer = []
    reduce_grad_buffer = []
    reward_buffer = []
    params_buffer = []
    flops_list = []
    params_list = []
    # cifar_mu = np.ones((3, 32, 32))
    # cifar_mu[0, :, :] = 0.4914
    # cifar_mu[1, :, :] = 0.4822
    # cifar_mu[2, :, :] = 0.4465
    # (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    # cifar_std = np.ones((3, 32, 32))
    # cifar_std[0, :, :] = 0.2471
    # cifar_std[1, :, :] = 0.2435
    # cifar_std[2, :, :] = 0.2616
    # criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # classifier = PyTorchClassifier(
    #     model=model,
    #     clip_values=(0.0, 1.0),
    #     preprocessing=(cifar_mu, cifar_std),
    #     loss=criterion,
    #     optimizer=optimizer,
    #     input_shape=(3, 32, 32),
    #     nb_classes=10,
    # )
    for batch_idx in range(model.module.rl_batch_size):
        # Sample several networks and evaluate each of them.
        # sample the submodel
        # if stage == 1:
        #     print("ok")
        normal_indices, reduce_indices, genotype = get_cur_model(model, cur_switches_normal, cur_switches_reduce)
        # return 0.0, 0.0
        # attack = FastGradientMethod(estimator=model, eps=0.2)
        # x_test_adv = attack.generate(x=x_test)
        # res = clever_u(classifier,valid_queue.dataset.data[-1].transpose(2,0,1) , 2, 2, R_LI, norm=np.inf, pool_factor=3)
        # print(res)
        # validate the sub_model on the held-out batch (no grads needed here;
        # the policy gradient comes from the log-probs, not this forward).
        with torch.no_grad():
            logits= model(input_search)
            prec1, _ = utils.accuracy(logits, target_search, topk=(1,5))
        # Build a standalone copy of the sampled genotype to measure cost.
        sub_model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, 20, False, genotype)
        sub_model.drop_path_prob = 0
        # para0 = utils.count_parameters_in_MB(sub_model)
        input = torch.randn(1,3,32,32)
        flops, params = profile(sub_model, inputs = (input,), )
        flops_s, params_s = clever_format([flops, params], "%.3f")
        # Convert to GFLOPs / M-params for reward computation and logging.
        flops, params = flops/1e9, params/1e6
        params_buffer.append(params)
        flops_list.append(flops_s)
        params_list.append(params_s)
        # prec1 = np.random.rand()
        # Zero arch-parameter grads so backward() below captures only this
        # sample's log-prob gradient.
        if model.module._arch_parameters[0].grad is not None:
            model.module._arch_parameters[0].grad.data.zero_()
        if model.module._arch_parameters[1].grad is not None:
            model.module._arch_parameters[1].grad.data.zero_()
        # Sum of log-probs over the 14 edges of normal and reduce cells.
        obj_term = 0
        for i in range(14):
            obj_term = obj_term + model.module.normal_log_prob[i]
            obj_term = obj_term + model.module.reduce_log_prob[i]
        loss_term = -obj_term
        # backward
        loss_term.backward()
        # take out gradient dict (cloned so later samples don't overwrite it)
        normal_grad_buffer.append(model.module._arch_parameters[0].grad.data.clone())
        reduce_grad_buffer.append(model.module._arch_parameters[1].grad.data.clone())
        reward = calculate_reward(stage, prec1, params)
        reward_buffer.append(reward)
        # record best_reward index
        if prec1 > best_prec1:
            best_prec1 = prec1
            best_normal_indices = normal_indices
            best_reduce_indices = reduce_indices
        # else:
        #     best_normal_indices = []
        #     best_reduce_indices = []
    logging.info(flops_list)
    logging.info(params_list)
    # NOTE(review): these log the *last* sampled architecture's indices only.
    logging.info(normal_indices.detach().cpu().numpy().squeeze())
    logging.info(reduce_indices.detach().cpu().numpy().squeeze())
    logging.info(genotype)
    avg_reward = sum(reward_buffer) / model.module.rl_batch_size
    avg_params = sum(params_buffer) / model.module.rl_batch_size
    # Exponential-moving-average baseline for variance reduction.
    if model.module.baseline == 0:
        model.module.baseline = avg_reward
    else:
        model.module.baseline += model.module.baseline_decay_weight * (avg_reward - model.module.baseline)
        # hs
        # model.module.baseline = model.module.baseline_decay_weight * model.module.baseline + \
        #     (1-model.module.baseline_decay_weight) * avg_reward
    # Combine per-sample gradients: mean over the batch of
    # (reward - baseline) * grad(log-prob).
    model.module._arch_parameters[0].grad.data.zero_()
    model.module._arch_parameters[1].grad.data.zero_()
    for j in range(model.module.rl_batch_size):
        model.module._arch_parameters[0].grad.data += (reward_buffer[j] - model.module.baseline) * normal_grad_buffer[j]
        model.module._arch_parameters[1].grad.data += (reward_buffer[j] - model.module.baseline) * reduce_grad_buffer[j]
    model.module._arch_parameters[0].grad.data /= model.module.rl_batch_size
    model.module._arch_parameters[1].grad.data /= model.module.rl_batch_size
    # if step % 50 == 0:
    #     logging.info(model.module._arch_parameters[0].grad.data)
    #     logging.info(model.module._arch_parameters[0])
    # apply gradients
    nn.utils.clip_grad_norm_(model.module.arch_parameters(), args.grad_clip)
    optimizer_a.step()
    if step % args.report_freq == 0:
        # logging.info(model.module._arch_parameters[0])
        # valid the argmax arch
        logging.info('REINFORCE [step %d]\t\tMean Reward %.4f\tBaseline %.4f\tBest Sampled Prec1 %.4f', step, avg_reward, model.module.baseline, best_prec1)
        # Evaluate the greedy (argmax) architecture on the same batch.
        max_normal_index, max_reduce_index = set_max_model(model, cur_switches_normal, cur_switches_reduce)
        logits= model(input_search)
        prec1, _ = utils.accuracy(logits, target_search, topk=(1,5))
        logging.info('REINFORCE [step %d]\t\tCurrent Max Architecture Reward %.4f\t\tAvarage Params %.3f', step, prec1/100, avg_params)
        max_arch_reward_writer.add_scalar('max_arch_reward_{}'.format(stage), prec1, tb_index[stage])
        avg_params_writer.add_scalar('avg_params_{}'.format(stage), avg_params, tb_index[stage])
        logging.info(max_normal_index)
        logging.info(max_reduce_index)
        best_reward_arch_writer.add_scalar('best_prec1_arch_{}'.format(stage), best_prec1, tb_index[stage])
        logging.info(np.around(torch.Tensor(reward_buffer).numpy(),3))
        # logging.info(model.module.normal_probs)
        # logging.info(model.module.reduce_probs)
        logging.info(model.module.alphas_normal)
        logging.info(model.module.alphas_reduce)
        # Per-edge TensorBoard traces: argmax op, the k lowest-probability
        # ops (candidates to drop this stage), and the best sampled index.
        for i in range(14):
            normal_max_writer[i].add_scalar('normal_max_arch_{}'.format(stage), np.argmax(model.module.normal_probs.detach().cpu()[i].numpy()), tb_index[stage])
            normal_min_k = get_min_k(model.module.normal_probs.detach().cpu()[i].numpy(), normal_num_to_drop[stage])
            for j in range(normal_num_to_drop[stage]):
                normal_min_writer[i][j].add_scalar('normal_min_arch_{}_{}'.format(stage, j), normal_min_k[j], tb_index[stage])
            best_normal_writer[i].add_scalar('best_normal_index_{}'.format(stage), best_normal_indices[i].cpu().numpy(), tb_index[stage])
        for i in range(14):
            reduce_max_writer[i].add_scalar('reduce_max_arch_{}'.format(stage), np.argmax(model.module.reduce_probs.detach().cpu()[i].numpy()), tb_index[stage])
            reduce_min_k = get_min_k(model.module.reduce_probs.detach().cpu()[i].numpy(), reduce_num_to_drop[stage])
            for j in range(reduce_num_to_drop[stage]):
                reduce_min_writer[i][j].add_scalar('reduce_min_arch_{}_{}'.format(stage, j), reduce_min_k[j], tb_index[stage])
            best_reduce_writer[i].add_scalar('best_reduce_index_{}'.format(stage), best_reduce_indices[i].cpu().numpy(), tb_index[stage])
        # Reset the best-of-window tracker and advance the TB x-axis, then
        # undo set_max_model's modification of the supernet.
        best_prec1 = 0
        tb_index[stage]+=1
        model.module.restore_super_net()
def process_logs(args) -> DataFrame:
    """Parse a single search log into a per-weight summary DataFrame.

    Reads the whole log from ``args.log``, detects which regularizer (L1/L2)
    was used, recovers the last ``hist`` dict dumped into the log, drops
    dominated (non-Pareto-optimal) samples, then builds one row per
    regularization weight with search metrics plus profiling results from a
    re-instantiated network. Prints the table and writes it to ``args.output``.
    """
    # filter logs
    lines = str(args.log.readlines())
    # The log records which regularization was selected; map it to a key.
    match = re.search(r"Selected regularization(?P<reg>.*?)\\n", lines)
    reg = match.group("reg")
    if L1_LOSS in reg:
        loss = L1_LOSS
    elif L2_LOSS in reg:
        loss = L2_LOSS
    else:
        raise RuntimeError("Cant decode line Selected regularization")
    # Take the LAST 'hist = {...}' occurrence — the most complete snapshot.
    match = re.finditer(r"hist = (?P<hist>.*?)\\n", lines)
    hist_str = list(match)[-1].group("hist")
    # NOTE(security): eval of log text — acceptable only for trusted logs.
    hist = eval(hist_str)[loss]
    print("Removing non-optimal samples")
    # filter out dominated points (mutates hist in place)
    filter_hist(hist)
    data = []
    for weight, result in hist.items():
        row = []
        name = create_genotype_name(weight, loss)
        try:
            row.append(name)
            # weight is a tuple; the scalar weight value is its first element
            weight_value = weight[0]
            row.append(weight_value)
            # Each hist entry looks like:
            # {'train_acc': 25.035999994506835,
            #  'valid_acc': 20.171999999084473,
            #  'reg_loss': 16.01249122619629,
            #  'criterion_loss': 1.9922981262207031,
            #  'model_size': 1.81423,
            #  'genotype': Genotype(..)}
            for metric in [
                SIZE, TRAIN_ACC, VALID_ACC, CRITERION_LOSS, REG_LOSS
            ]:
                row.append(result[metric])
        except Exception as e:
            print(f"Error '{e}' while processing file {args.log} w={weight}")
        # Pad so profiling columns stay aligned on partial failure.
        while len(row) < 7:
            row.append(None)
        try:
            # model profiling
            genotype = result[GENOTYPE]
            # using default from train.py for CIFAR10
            model = NetworkCIFAR(36, 10, 20, False, genotype)
            model.cuda()
            model.drop_path_prob = 0.3
            parameters, net_flops, total_time_gpu, total_time_cpu = model_profiling(
                model, name)
            row.append(parameters)
            row.append(net_flops)
            row.append(total_time_gpu)
            row.append(total_time_cpu)
        except Exception as e:
            print(f"Error '{e}' while processing file {args.log} w={weight}")
            # Profiling failures are fatal (unlike the metric block above).
            raise e
        if len(row) > 0:
            data.append(row)
    df = pd.DataFrame(data, columns=[
        MODEL_NAME, WEIGHT, "Params", SEARCH_TRAIN_ACC, SEARCH_VAL_ACC,
        SEARCH_CRIT_LOSS, SEARCH_REG_LOSS, "Parameters", FLOPS,
        "Latency GPU", LATENCY_CPU
    ])
    df.set_index(keys=MODEL_NAME, inplace=True)
    df.sort_values(by=WEIGHT, inplace=True, ascending=False)
    pd.set_option("display.max_rows", None, "display.max_columns", None,
                  "display.width", None)
    print(df)
    df.to_csv(args.output)
    return df
def main():
    """Entry point: parse CLI args, build the network, then train or evaluate.

    Flow: argparse setup -> seeding/CUDA setup -> model construction
    (MultiChannelNetwork / NetworkImageNet / NetworkCIFAR) -> optional
    FLOPs-count-and-exit -> data queues -> optional evaluate-and-exit ->
    cosine-power-annealing training loop with per-epoch stats JSON/CSV dumps
    -> final evaluation of the best checkpoint.
    """
    parser = argparse.ArgumentParser("Common Argument Parser")
    parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
    parser.add_argument('--dataset', type=str, default='cifar10', help='which dataset:\ cifar10, mnist, emnist, fashion, svhn, stl10, devanagari')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
    parser.add_argument('--learning_rate_min', type=float, default=1e-8, help='min learning rate')
    parser.add_argument('--lr_power_annealing_exponent_order', type=float, default=2,
                        help='Cosine Power Annealing Schedule Base, larger numbers make '
                             'the exponential more dominant, smaller make cosine more dominant, '
                             '1 returns to standard cosine annealing.')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', '--wd', dest='weight_decay', type=float, default=3e-4, help='weight decay')
    parser.add_argument('--partial', default=1/8, type=float, help='partially adaptive parameter p in Padam')
    parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
    parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
    parser.add_argument('--epochs', type=int, default=2000, help='num of training epochs')
    parser.add_argument('--start_epoch', default=1, type=int, metavar='N', help='manual epoch number (useful for restarts)')
    parser.add_argument('--warmup_epochs', type=int, default=5, help='num of warmup training epochs')
    parser.add_argument('--warm_restarts', type=int, default=20, help='warm restarts of cosine annealing')
    parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
    parser.add_argument('--mid_channels', type=int, default=32, help='C_mid channels in choke SharpSepConv')
    parser.add_argument('--layers', type=int, default=20, help='total number of layers')
    parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
    parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
    parser.add_argument('--mixed_auxiliary', action='store_true', default=False, help='Learn weights for auxiliary networks during training. Overrides auxiliary flag')
    parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
    parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
    parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
    parser.add_argument('--autoaugment', action='store_true', default=False, help='use cifar10 autoaugment https://arxiv.org/abs/1805.09501')
    parser.add_argument('--random_eraser', action='store_true', default=False, help='use random eraser')
    parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
    parser.add_argument('--save', type=str, default='EXP', help='experiment name')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
    parser.add_argument('--ops', type=str, default='OPS', help='which operations to use, options are OPS and DARTS_OPS')
    parser.add_argument('--primitives', type=str, default='PRIMITIVES',
                        help='which primitive layers to use inside a cell search space,'
                             ' options are PRIMITIVES, SHARPER_PRIMITIVES, and DARTS_PRIMITIVES')
    parser.add_argument('--optimizer', type=str, default='sgd', help='which optimizer to use, options are padam and sgd')
    parser.add_argument('--load', type=str, default='', metavar='PATH', help='load weights at specified location')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
    parser.add_argument('--flops', action='store_true', default=False, help='count flops and exit, aka floating point operations.')
    parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, metavar='PATH', default='',
                        help='evaluate model at specified path on training, test, and validation datasets')
    parser.add_argument('--multi_channel', action='store_true', default=False, help='perform multi channel search, a completely separate search space')
    parser.add_argument('--load_args', type=str, default='', metavar='PATH',
                        help='load command line args from a json file, this will override '
                             'all currently set args except for --evaluate, and arguments '
                             'that did not exist when the json file was originally saved out.')
    parser.add_argument('--layers_of_cells', type=int, default=8, help='total number of cells in the whole network, default is 8 cells')
    parser.add_argument('--layers_in_cells', type=int, default=4,
                        help='Total number of nodes in each cell, aka number of steps,'
                             ' default is 4 nodes, which implies 8 ops')
    parser.add_argument('--weighting_algorithm', type=str, default='scalar',
                        help='which operations to use, options are '
                             '"max_w" (1. - max_w + w) * op, and scalar (w * op)')
    # TODO(ahundt) remove final path and switch back to genotype
    parser.add_argument('--load_genotype', type=str, default=None, help='Name of genotype to be used')
    parser.add_argument('--simple_path', default=True, action='store_false', help='Final model is a simple path (MultiChannelNetworkModel)')
    args = parser.parse_args()
    # Sets up save dir / log files and may merge args from --load_args json.
    args = utils.initialize_files_and_args(args)
    logger = utils.logging_setup(args.log_file_path)
    if not torch.cuda.is_available():
        logger.info('no gpu device available')
        sys.exit(1)
    # Seed every RNG source for reproducibility.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logger.info('gpu device = %d' % args.gpu)
    logger.info("args = %s", args)
    # Dataset metadata lookups (not all of these are used below).
    DATASET_CLASSES = dataset.class_dict[args.dataset]
    DATASET_CHANNELS = dataset.inp_channel_dict[args.dataset]
    DATASET_MEAN = dataset.mean_dict[args.dataset]
    DATASET_STD = dataset.std_dict[args.dataset]
    logger.info('output channels: ' + str(DATASET_CLASSES))
    # # load the correct ops dictionary
    # NOTE(security): eval of a string built from CLI input — trusted-CLI only.
    op_dict_to_load = "operations.%s" % args.ops
    logger.info('loading op dict: ' + str(op_dict_to_load))
    op_dict = eval(op_dict_to_load)
    # load the correct primitives list
    primitives_to_load = "genotypes.%s" % args.primitives
    logger.info('loading primitives:' + primitives_to_load)
    primitives = eval(primitives_to_load)
    logger.info('primitives: ' + str(primitives))
    genotype = eval("genotypes.%s" % args.arch)
    # create the neural network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    if args.multi_channel:
        final_path = None
        if args.load_genotype is not None:
            genotype = getattr(genotypes, args.load_genotype)
            print(genotype)
            if type(genotype[0]) is str:
                logger.info('Path :%s', genotype)
        # TODO(ahundt) remove final path and switch back to genotype
        cnn_model = MultiChannelNetwork(
            args.init_channels, DATASET_CLASSES, layers=args.layers_of_cells,
            criterion=criterion, steps=args.layers_in_cells,
            weighting_algorithm=args.weighting_algorithm, genotype=genotype)
        flops_shape = [1, 3, 32, 32]
    elif args.dataset == 'imagenet':
        cnn_model = NetworkImageNet(args.init_channels, DATASET_CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
        flops_shape = [1, 3, 224, 224]
    else:
        cnn_model = NetworkCIFAR(args.init_channels, DATASET_CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
        flops_shape = [1, 3, 32, 32]
    cnn_model = cnn_model.cuda()
    logger.info("param size = %fMB", utils.count_parameters_in_MB(cnn_model))
    if args.flops:
        # Count FLOPs and exit without training.
        logger.info('flops_shape = ' + str(flops_shape))
        logger.info("flops = " + utils.count_model_flops(cnn_model, data_shape=flops_shape))
        return
    optimizer = torch.optim.SGD(
        cnn_model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    # Get preprocessing functions (i.e. transforms) to apply on data
    train_transform, valid_transform = utils.get_data_transforms(args)
    if args.evaluate:
        # evaluate the train dataset without augmentation
        train_transform = valid_transform
    # Get the training queue, use full training and test set
    train_queue, valid_queue = dataset.get_training_queues(
        args.dataset, train_transform, valid_transform, args.data,
        args.batch_size, train_proportion=1.0, search_architecture=False)
    test_queue = None
    if args.dataset == 'cifar10':
        # evaluate best model weights on cifar 10.1
        # https://github.com/modestyachts/CIFAR-10.1
        test_data = cifar10_1.CIFAR10_1(root=args.data, download=True, transform=valid_transform)
        test_queue = torch.utils.data.DataLoader(
            test_data, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=8)
    if args.evaluate:
        # evaluate the loaded model, print the result, and return
        logger.info("Evaluating inference with weights file: " + args.load)
        eval_stats = evaluate(
            args, cnn_model, criterion, args.load,
            train_queue=train_queue, valid_queue=valid_queue, test_queue=test_queue)
        with open(args.stats_file, 'w') as f:
            arg_dict = vars(args)
            arg_dict.update(eval_stats)
            json.dump(arg_dict, f)
        logger.info("flops = " + utils.count_model_flops(cnn_model))
        logger.info(utils.dict_to_log_string(eval_stats))
        logger.info('\nEvaluation of Loaded Model Complete! Save dir: ' + str(args.save))
        return
    # Precompute the full learning-rate schedule for all epochs.
    lr_schedule = cosine_power_annealing(
        epochs=args.epochs, max_lr=args.learning_rate, min_lr=args.learning_rate_min,
        warmup_epochs=args.warmup_epochs, exponent_order=args.lr_power_annealing_exponent_order)
    epochs = np.arange(args.epochs) + args.start_epoch
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
    epoch_stats = []
    stats_csv = args.epoch_stats_file
    stats_csv = stats_csv.replace('.json', '.csv')
    with tqdm(epochs, dynamic_ncols=True) as prog_epoch:
        best_valid_acc = 0.0
        best_epoch = 0
        best_stats = {}
        stats = {}
        epoch_stats = []
        weights_file = os.path.join(args.save, 'weights.pt')
        for epoch, learning_rate in zip(prog_epoch, lr_schedule):
            # update the drop_path_prob augmentation (linear ramp over epochs)
            cnn_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
            # update the learning rate from the precomputed schedule
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate # scheduler.get_lr()[0]
            train_acc, train_obj = train(args, train_queue, cnn_model, criterion, optimizer)
            val_stats = infer(args, valid_queue, cnn_model, criterion)
            stats.update(val_stats)
            stats['train_acc'] = train_acc
            stats['train_loss'] = train_obj
            stats['lr'] = learning_rate
            stats['epoch'] = epoch
            if stats['valid_acc'] > best_valid_acc:
                # new best epoch, save weights
                utils.save(cnn_model, weights_file)
                best_epoch = epoch
                best_stats.update(copy.deepcopy(stats))
                best_valid_acc = stats['valid_acc']
                best_train_loss = train_obj
                best_train_acc = train_acc
            # else:
            #     # not best epoch, load best weights
            #     utils.load(cnn_model, weights_file)
            logger.info('epoch, %d, train_acc, %f, valid_acc, %f, train_loss, %f, valid_loss, %f, lr, %e, best_epoch, %d, best_valid_acc, %f, ' +
                        utils.dict_to_log_string(stats),
                        epoch, train_acc, stats['valid_acc'], train_obj,
                        stats['valid_loss'], learning_rate, best_epoch, best_valid_acc)
            stats['train_acc'] = train_acc
            stats['train_loss'] = train_obj
            # Persist per-epoch stats after every epoch (JSON + CSV).
            epoch_stats += [copy.deepcopy(stats)]
            with open(args.epoch_stats_file, 'w') as f:
                json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
            utils.list_of_dicts_to_csv(stats_csv, epoch_stats)
    # get stats from best epoch including cifar10.1
    eval_stats = evaluate(args, cnn_model, criterion, weights_file, train_queue, valid_queue, test_queue)
    with open(args.stats_file, 'w') as f:
        arg_dict = vars(args)
        arg_dict.update(eval_stats)
        json.dump(arg_dict, f, cls=utils.NumpyEncoder)
    with open(args.epoch_stats_file, 'w') as f:
        json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
    logger.info(utils.dict_to_log_string(eval_stats))
    logger.info('Training of Final Model Complete! Save dir: ' + str(args.save))
def main():
    """Train a fixed (or randomly generated) architecture on the selected dataset.

    Resolves the genotype (random, dataset-specific table, named, or loaded
    from file), builds the inference network, splits the training data into
    train/validation queues, and runs a fixed-LR training loop, saving loss
    history arrays and model weights every epoch.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    if not args.random:
        # We would always get the same random architecture if we set the random
        # seed here. We'll set it after finding a random genotype.
        torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    train_data, OUTPUT_DIM, IN_CHANNELS, is_regression = load_dataset(args, train=True)
    # Regression datasets use MSE; classification uses cross-entropy.
    criterion = nn.CrossEntropyLoss() if not is_regression else nn.MSELoss()
    if args.random:
        # Build a throwaway search network just to sample a random genotype.
        model_tmp = Network(C=args.init_channels, num_classes=OUTPUT_DIM,
                            layers=args.layers, primitives_name=primitives_name,
                            criterion=criterion, num_channels=IN_CHANNELS)
        genotype = model_tmp.genotype() # Random
        # We can now set the random seed.
        torch.manual_seed(args.seed)
    # If the architecture was the default DATASET, look up the architecture corresponding to this dataset
    elif args.arch == 'DATASET':
        genotype = GENOTYPE_TBL[dataset]
        print(f'using genotype for {dataset}')
    else:
        # Named genotype first; fall back to loading an .arch file by path.
        try:
            genotype = eval("genotypes.%s" % args.arch)
        except (AttributeError, SyntaxError):
            genotype = genotypes.load_genotype_from_file(args.arch)
    # Record the genotype used alongside the experiment outputs.
    genotypes.save_genotype_to_file(genotype, os.path.join(args.save, "genotype.arch"))
    # Set the inference network; default is NetworkCifar10; supported alternatives NetworkGalaxyZoo
    if dataset == 'GalaxyZoo' and args.gz_dtree:
        model = NetworkGalaxyZoo(C=args.init_channels, num_classes=OUTPUT_DIM,
                                 layers=args.layers, genotype=genotype,
                                 fc1_size=args.fc1_size, fc2_size=args.fc2_size,
                                 num_channels=IN_CHANNELS)
    else:
        model = NetworkCIFAR(args.init_channels, OUTPUT_DIM, args.layers,
                             args.auxiliary, genotype, num_channels=IN_CHANNELS)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    criterion = criterion.cuda()
    # build optimizer based on optimizer input; one of SGD or Adam
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer= torch.optim.Adam(
            params=model.parameters(),
            lr=args.learning_rate,
            betas=(0.90, 0.999),
            weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Bad optimizer; got {args.optimizer}, must be one of 'SGD' or 'Adam'.")
    # Split training data into training and validation queues
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=2)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
    # history of training and validation loss; 2 columns for loss and accuracy / R2
    hist_trn = np.zeros((args.epochs, 2))
    hist_val = np.zeros((args.epochs, 2))
    metric_name = 'accuracy' if not is_regression else 'R2'
    for epoch in range(args.epochs):
        # scheduler.step()
        # logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # LR scheduler is disabled — the logged rate is the fixed initial LR.
        logging.info('epoch %d lr %e', epoch, args.learning_rate)
        # model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
        # drop-path is held constant rather than ramped over epochs.
        model.drop_path_prob = args.drop_path_prob
        # training results
        train_acc, train_obj = train(train_queue, model, criterion, optimizer, is_regression=is_regression)
        logging.info(f'training loss; {metric_name}: {train_obj:e} {train_acc:f}')
        # save history to numpy arrays (rewritten every epoch)
        hist_trn[epoch] = [train_acc, train_obj]
        np.save(os.path.join(args.save, 'hist_trn'), hist_trn)
        # validation results
        valid_acc, valid_obj = infer(valid_queue, model, criterion, is_regression=is_regression)
        logging.info(f'validation loss; {metric_name}: {valid_obj:e} {valid_acc:f}')
        # save history to numpy arrays
        hist_val[epoch] = [valid_acc, valid_obj]
        np.save(os.path.join(args.save, 'hist_val'), hist_val)
        # save current model weights (overwrites — last epoch, not best)
        utils.save(model, os.path.join(args.save, 'weights.pt'))
genotype=None) elif args.dataset == 'imagenet': cnn_model = NetworkImageNet(args.init_channels, classes, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels) # workaround for graph generation limitations cnn_model.drop_path_prob = torch.zeros(1) else: cnn_model = NetworkCIFAR(args.init_channels, classes, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels) # workaround for graph generation limitations cnn_model.drop_path_prob = torch.zeros(1) transforms = [ hl.transforms.Fold('MaxPool3x3 > Conv1x1 > BatchNorm', 'ResizableMaxPool', 'ResizableMaxPool'), hl.transforms.Fold('MaxPool > Conv > BatchNorm', 'ResizableMaxPool', 'ResizableMaxPool'), hl.transforms.Fold('Relu > Conv > Conv > BatchNorm', 'ReluSepConvBn'), hl.transforms.Fold('ReluSepConvBn > ReluSepConvBn', 'SharpSepConv', 'SharpSepConv'), hl.transforms.Fold('Relu > Conv > BatchNorm', 'ReLUConvBN'),
def main():
    """Entry point: train (or evaluate) a fixed NAS-derived network, optionally
    distributed across GPUs with NCCL and optionally in fp16.

    Reads configuration from the module-level ``args`` namespace; writes logs,
    per-epoch JSON/CSV stats, and checkpoints under ``args.save``. Mutates the
    module-level ``best_top1``, ``args`` and ``logger`` globals.
    """
    global best_top1, args, logger
    # Distributed mode is inferred from the launcher environment, not argparse.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    # commented because it is now set as an argparse param.
    # args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # One process per GPU: map this process's local rank onto a device.
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    # note the gpu is used for directory creation and log files
    # which is needed when run as multiple processes
    args = utils.initialize_files_and_args(args)
    logger = utils.logging_setup(args.log_file_path)

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            logger.info(
                "Warning: if --fp16 is not used, static_loss_scale will be ignored."
            )

    # # load the correct ops dictionary
    # NOTE(review): eval() on argparse strings executes arbitrary code if the
    # CLI is exposed to untrusted input — consider getattr(operations, args.ops).
    op_dict_to_load = "operations.%s" % args.ops
    logger.info('loading op dict: ' + str(op_dict_to_load))
    op_dict = eval(op_dict_to_load)

    # load the correct primitives list
    primitives_to_load = "genotypes.%s" % args.primitives
    logger.info('loading primitives:' + primitives_to_load)
    primitives = eval(primitives_to_load)
    logger.info('primitives: ' + str(primitives))
    # create model
    genotype = eval("genotypes.%s" % args.arch)
    # get the number of output channels
    classes = dataset.class_dict[args.dataset]
    # create the neural network
    if args.dataset == 'imagenet':
        model = NetworkImageNet(args.init_channels, classes, args.layers,
                                args.auxiliary, genotype, op_dict=op_dict,
                                C_mid=args.mid_channels)
        flops_shape = [1, 3, 224, 224]
    else:
        model = NetworkCIFAR(args.init_channels, classes, args.layers,
                             args.auxiliary, genotype, op_dict=op_dict,
                             C_mid=args.mid_channels)
        flops_shape = [1, 3, 32, 32]
    # Disable drop-path until the epoch loop ramps it up.
    model.drop_path_prob = 0.0
    # if args.pretrained:
    #     logger.info("=> using pre-trained model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch](pretrained=True)
    # else:
    #     logger.info("=> creating model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch]()

    # --flops: report parameter count and FLOPs, then exit without training.
    if args.flops:
        model = model.cuda()
        logger.info("param size = %fMB", utils.count_parameters_in_MB(model))
        logger.info("flops_shape = " + str(flops_shape))
        logger.info("flops = " + utils.count_model_flops(model, data_shape=flops_shape))
        return

    if args.sync_bn:
        import apex
        logger.info("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    if args.fp16:
        model = network_to_half(model)
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Scale learning rate based on global batch size
    args.learning_rate = args.learning_rate * float(
        args.batch_size * args.world_size) / 256.
    # Start at a reduced LR; the cosine-power-annealing schedule below
    # handles the warmup ramp.
    init_lr = args.learning_rate / args.warmup_lr_divisor
    optimizer = torch.optim.SGD(model.parameters(),
                                init_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # epoch_count = args.epochs - args.start_epoch
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epoch_count))
    # scheduler = warmup_scheduler.GradualWarmupScheduler(
    #     optimizer, args.warmup_lr_divisor, args.warmup_epochs, scheduler)
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        if args.evaluate:
            # --evaluate takes a checkpoint path; reuse the resume machinery.
            args.resume = args.evaluate
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                logger.info("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                if 'best_top1' in checkpoint:
                    # NOTE(review): this binds a *local* best_top1; without a
                    # `global best_top1` declaration here the module-level value
                    # used by the epoch loop is NOT restored — confirm intent.
                    best_top1 = checkpoint['best_top1']
                model.load_state_dict(checkpoint['state_dict'])
                # An FP16_Optimizer instance's state dict internally stashes the master params.
                optimizer.load_state_dict(checkpoint['optimizer'])
                # TODO(ahundt) make sure scheduler loading isn't broken
                # NOTE(review): `scheduler` is commented out above, so the
                # 'lr_scheduler' branch would raise NameError if taken; the
                # 'lr_schedule' assignment is also a dead local — verify.
                if 'lr_scheduler' in checkpoint:
                    scheduler.load_state_dict(checkpoint['lr_scheduler'])
                elif 'lr_schedule' in checkpoint:
                    lr_schedule = checkpoint['lr_schedule']
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

        resume()

    # # Data loading code
    # traindir = os.path.join(args.data, 'train')
    # valdir = os.path.join(args.data, 'val')

    # if(args.arch == "inception_v3"):
    #     crop_size = 299
    #     val_size = 320  # I chose this value arbitrarily, we can adjust.
    # else:
    #     crop_size = 224
    #     val_size = 256

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(crop_size),
    #         transforms.RandomHorizontalFlip(),
    #         autoaugment.ImageNetPolicy(),
    #         # transforms.ToTensor(), Too slow, moved to data_prefetcher()
    #         # normalize,
    #     ]))
    # val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
    #     transforms.Resize(val_size),
    #     transforms.CenterCrop(crop_size)
    # ]))

    # train_sampler = None
    # val_sampler = None
    # if args.distributed:
    #     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    #     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    #     num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True,
    #     sampler=val_sampler,
    #     collate_fn=fast_collate)

    # Get preprocessing functions (i.e. transforms) to apply on data
    # normalize_as_tensor = False because we normalize and convert to a
    # tensor in our custom prefetching function, rather than as part of
    # the transform preprocessing list.
    train_transform, valid_transform = utils.get_data_transforms(
        args, normalize_as_tensor=False)
    # Get the training queue, select training and validation from training set
    train_loader, val_loader = dataset.get_training_queues(
        args.dataset,
        train_transform,
        valid_transform,
        args.data,
        args.batch_size,
        train_proportion=1.0,
        collate_fn=fast_collate,
        distributed=args.distributed,
        num_workers=args.workers)

    if args.evaluate:
        if args.dataset == 'cifar10':
            # evaluate best model weights on cifar 10.1
            # https://github.com/modestyachts/CIFAR-10.1
            train_transform, valid_transform = utils.get_data_transforms(args)
            # Get the training queue, select training and validation from training set
            # Get the training queue, use full training and test set
            train_queue, valid_queue = dataset.get_training_queues(
                args.dataset,
                train_transform,
                valid_transform,
                args.data,
                args.batch_size,
                train_proportion=1.0,
                search_architecture=False)
            test_data = cifar10_1.CIFAR10_1(root=args.data,
                                            download=True,
                                            transform=valid_transform)
            test_queue = torch.utils.data.DataLoader(
                test_data,
                batch_size=args.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=args.workers)
            eval_stats = evaluate(args,
                                  model,
                                  criterion,
                                  train_queue=train_queue,
                                  valid_queue=valid_queue,
                                  test_queue=test_queue)
            with open(args.stats_file, 'w') as f:
                # TODO(ahundt) fix "TypeError: 1869 is not JSON serializable" to include arg info, see train.py
                # arg_dict = vars(args)
                # arg_dict.update(eval_stats)
                # json.dump(arg_dict, f)
                json.dump(eval_stats, f)
            logger.info("flops = " + utils.count_model_flops(model))
            logger.info(utils.dict_to_log_string(eval_stats))
            logger.info('\nEvaluation of Loaded Model Complete! Save dir: ' + str(args.save))
        else:
            validate(val_loader, model, criterion, args)
        return

    # Precompute the full per-epoch LR schedule (warmup + cosine power decay).
    lr_schedule = cosine_power_annealing(
        epochs=args.epochs,
        max_lr=args.learning_rate,
        min_lr=args.learning_rate_min,
        warmup_epochs=args.warmup_epochs,
        exponent_order=args.lr_power_annealing_exponent_order,
        restart_lr=args.restart_lr)
    epochs = np.arange(args.epochs) + args.start_epoch
    # Per-epoch stats mirror the JSON file with a .csv extension.
    stats_csv = args.epoch_stats_file
    stats_csv = stats_csv.replace('.json', '.csv')
    with tqdm(epochs, dynamic_ncols=True, disable=args.local_rank != 0,
              leave=False) as prog_epoch:
        best_stats = {}
        stats = {}
        epoch_stats = []
        best_epoch = 0
        for epoch, learning_rate in zip(prog_epoch, lr_schedule):
            if args.distributed and train_loader.sampler is not None:
                # Reshuffle shards so each rank sees a different split per epoch.
                train_loader.sampler.set_epoch(int(epoch))
            # if args.distributed:
            #     train_sampler.set_epoch(epoch)
            # update the learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate
            # scheduler.step()
            # Linearly ramp drop-path regularization over the run.
            model.drop_path_prob = args.drop_path_prob * float(epoch) / float(
                args.epochs)
            # train for one epoch
            train_stats = train(train_loader, model, criterion, optimizer,
                                int(epoch), args)
            if args.prof:
                # Profiling mode: stop after a single epoch.
                break
            # evaluate on validation set
            top1, val_stats = validate(val_loader, model, criterion, args)
            stats.update(train_stats)
            stats.update(val_stats)

            # stats['lr'] = '{0:.5f}'.format(scheduler.get_lr()[0])
            stats['lr'] = '{0:.5f}'.format(learning_rate)
            stats['epoch'] = epoch

            # remember best top1 and save checkpoint
            # Only rank 0 checkpoints and writes stats files to avoid clobbering.
            if args.local_rank == 0:
                is_best = top1 > best_top1
                best_top1 = max(top1, best_top1)
                stats['best_top1'] = '{0:.3f}'.format(best_top1)
                if is_best:
                    best_epoch = epoch
                    best_stats = copy.deepcopy(stats)
                stats['best_epoch'] = best_epoch

                stats_str = utils.dict_to_log_string(stats)
                logger.info(stats_str)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_top1': best_top1,
                        'optimizer': optimizer.state_dict(),
                        # 'lr_scheduler': scheduler.state_dict()
                        'lr_schedule': lr_schedule,
                        'stats': best_stats
                    },
                    is_best,
                    path=args.save)
                prog_epoch.set_description(
                    'Overview ***** best_epoch: {0} best_valid_top1: {1:.2f} ***** Progress'
                    .format(best_epoch, best_top1))
                # Rewrite the cumulative per-epoch stats every epoch so a crash
                # loses at most the current epoch.
                epoch_stats += [copy.deepcopy(stats)]
                with open(args.epoch_stats_file, 'w') as f:
                    json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
                utils.list_of_dicts_to_csv(stats_csv, epoch_stats)

    # Final summary: best-epoch stats plus the full arg namespace.
    stats_str = utils.dict_to_log_string(best_stats, key_prepend='best_')
    logger.info(stats_str)
    with open(args.stats_file, 'w') as f:
        arg_dict = vars(args)
        arg_dict.update(best_stats)
        json.dump(arg_dict, f, cls=utils.NumpyEncoder)
    with open(args.epoch_stats_file, 'w') as f:
        json.dump(epoch_stats, f, cls=utils.NumpyEncoder)
    utils.list_of_dicts_to_csv(stats_csv, epoch_stats)
    logger.info('Training of Final Model Complete! Save dir: ' + str(args.save))
def __init__(self, test_args: Namespace, my_dataset: MyDataset,
             model: nn.Module = None):
    """Set up a trained network and its test-set DataLoader for evaluation.

    Args:
        test_args: parsed CLI namespace (seed, gpu, arch/arch_path,
            model_path, data, batchsz, init_ch, layers, auxiliary,
            drop_path_prob, ...).
        my_dataset: which dataset to evaluate on (CIFAR10 / CIFAR100 /
            ImageNet).
        model: optional pre-built network; when None, a network is
            constructed from the genotype named by ``test_args.arch`` (or
            loaded from ``test_args.arch_path``) and its weights are read
            from ``test_args.model_path``.

    Raises:
        Exception: if ``my_dataset`` is not a recognized ``MyDataset`` member.
    """
    # Prefer GPU when available; everything below is moved to this device.
    self.__device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    # Seed numpy and torch for reproducible evaluation.
    np.random.seed(test_args.seed)
    torch.manual_seed(test_args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info(f'gpu device = {test_args.gpu}')
    logging.info(f'args = {test_args}')
    if model is None:
        # equal to: genotype = genotypes.DARTS_v2
        if not (test_args.arch or test_args.arch_path):
            logging.info('need to designate arch.')
            sys.exit(1)
        # arch_path (a serialized genotype file) takes priority over the
        # named genotype looked up via eval on genotypes.<arch>.
        genotype = eval(
            f'genotypes.{test_args.arch}'
        ) if not test_args.arch_path else utils.load_genotype(
            test_args.arch_path)
        print('Load genotype:', genotype)
        if my_dataset is MyDataset.CIFAR10:
            model = NetworkCIFAR(test_args.init_ch, 10, test_args.layers,
                                 test_args.auxiliary,
                                 genotype).to(self.__device)
        elif my_dataset is MyDataset.CIFAR100:
            model = NetworkCIFAR(test_args.init_ch, 100, test_args.layers,
                                 test_args.auxiliary,
                                 genotype).to(self.__device)
        elif my_dataset is MyDataset.ImageNet:
            model = NetworkImageNet(test_args.init_ch, 1000, test_args.layers,
                                    test_args.auxiliary,
                                    genotype).to(self.__device)
        else:
            raise Exception('No match MyDataset')
        # Restore trained weights for the freshly built network.
        utils.load(model, test_args.model_path, False)
        model = model.to(self.__device)
    param_size = utils.count_parameters_in_MB(model)
    logging.info(f'param size = {param_size}MB')
    model.drop_path_prob = test_args.drop_path_prob
    self.__model = model
    self.__args = test_args
    self.__criterion = nn.CrossEntropyLoss().to(self.__device)
    # Build the held-out test set matching the chosen dataset.
    if my_dataset is MyDataset.CIFAR10:
        _, test_transform = utils._data_transforms_cifar10(test_args)
        test_data = dset.CIFAR10(root=test_args.data,
                                 train=False,
                                 download=True,
                                 transform=test_transform)
    elif my_dataset is MyDataset.CIFAR100:
        _, test_transform = utils._data_transforms_cifar100(test_args)
        test_data = dset.CIFAR100(root=test_args.data,
                                  train=False,
                                  download=True,
                                  transform=test_transform)
    elif my_dataset is MyDataset.ImageNet:
        # NOTE(review): the `/` implies test_args.data is a pathlib.Path for
        # ImageNet (elsewhere it is passed as a plain root string) — confirm.
        validdir = test_args.data / 'val'
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        valid_data = dset.ImageFolder(
            validdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        # ImageNet evaluation reuses the standard validation split as "test".
        test_data = valid_data
    else:
        raise Exception('No match MyDataset')
    # Deterministic order (shuffle=False) so metrics are reproducible.
    self.__test_queue = torch.utils.data.DataLoader(
        test_data,
        batch_size=test_args.batchsz,
        shuffle=False,
        pin_memory=True,
        num_workers=4)
def __init__(self,
             args: Namespace,
             genotype: Genotype,
             my_dataset: MyDataset,
             choose_cell=False):
    """Set up model, data queues, optimizer and LR scheduler for training.

    Args:
        args: parsed CLI namespace (seed, epochs, arch/arch_path, save, data,
            batchsz, init_ch, layers, auxiliary, lr, momentum, wd,
            train_portion, checkpoint_path, to_parallel, label_smooth,
            decay_period, gamma, ...). ``args.epochs`` is reduced in place by
            the number of epochs already completed when resuming.
        genotype: architecture genotype to instantiate.
        my_dataset: dataset selector (CIFAR10 / CIFAR100 / ImageNet).
        choose_cell: when True, split the *training* set into train/validation
            queues by ``args.train_portion`` (architecture-selection mode);
            when False, use the full training set and the real held-out set.

    Raises:
        Exception: on missing seed/epochs/arch, unknown dataset, or when a
            resumed checkpoint leaves no epochs to train.
    """
    self.__args = args
    self.__dataset = my_dataset
    # Epochs already trained before a checkpoint resume (0 for a fresh run).
    self.__previous_epochs = 0

    # Required-argument validation.
    if args.seed is None:
        raise Exception('designate seed.')
    elif args.epochs is None:
        raise Exception('designate epochs.')
    if not (args.arch or args.arch_path):
        raise Exception('need to designate arch.')

    # Log to stdout and to <save>/log.txt with the same format.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # Seed numpy/torch for reproducibility (CUDA seeding happens below once
    # the single- vs multi-GPU mode is known).
    np.random.seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    torch.manual_seed(args.seed)
    logging.info(f'gpu device = {args.gpu}')
    logging.info(f'args = {args}')
    logging.info(f'Train genotype: {genotype}')

    # Build the network and the train/valid datasets for the chosen dataset.
    if my_dataset == MyDataset.CIFAR10:
        self.model = NetworkCIFAR(args.init_ch, 10, args.layers,
                                  args.auxiliary, genotype)
        train_transform, valid_transform = utils._data_transforms_cifar10(
            args)
        train_data = dset.CIFAR10(root=args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)
    elif my_dataset == MyDataset.CIFAR100:
        self.model = NetworkCIFAR(args.init_ch, 100, args.layers,
                                  args.auxiliary, genotype)
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = dset.CIFAR100(root=args.data,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=args.data,
                                   train=False,
                                   download=True,
                                   transform=valid_transform)
    elif my_dataset == MyDataset.ImageNet:
        self.model = NetworkImageNet(args.init_ch, 1000, args.layers,
                                     args.auxiliary, genotype)
        # ImageNet additionally uses label smoothing for the training loss.
        self.__criterion_smooth = CrossEntropyLabelSmooth(
            1000, args.label_smooth).to(device)
        traindir = os.path.join(args.data, 'train')
        validdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_data = dset.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4,
                                       hue=0.2),
                transforms.ToTensor(),
                normalize,
            ]))
        valid_data = dset.ImageFolder(
            validdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        raise Exception('No match Dataset')

    # Optionally resume weights/epoch count from a checkpoint; wrap in
    # DataParallel when multiple GPUs are in use.
    checkpoint = None
    if use_DataParallel:
        print('use Data Parallel')
        if args.checkpoint_path:
            checkpoint = torch.load(args.checkpoint_path)
            utils.load(self.model, checkpoint['state_dict'], args.to_parallel)
            self.__previous_epochs = checkpoint['epoch']
            # Remaining epochs = requested total minus already-trained.
            args.epochs -= self.__previous_epochs
            if args.epochs <= 0:
                raise Exception('args.epochs is too small.')
        self.model = nn.DataParallel(self.model)
        # Keep a handle to the wrapped module for optimizer/param access.
        self.__module = self.model.module
        torch.cuda.manual_seed_all(args.seed)
    else:
        if args.checkpoint_path:
            checkpoint = torch.load(args.checkpoint_path)
            utils.load(self.model, checkpoint['state_dict'], args.to_parallel)
            args.epochs -= checkpoint['epoch']
            if args.epochs <= 0:
                raise Exception('args.epochs is too small.')
        torch.cuda.manual_seed(args.seed)
        self.__module = self.model

    self.model.to(device)
    param_size = utils.count_parameters_in_MB(self.model)
    logging.info(f'param size = {param_size}MB')

    self.__criterion = nn.CrossEntropyLoss().to(device)
    self.__optimizer = torch.optim.SGD(self.__module.parameters(),
                                       args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.wd)
    if checkpoint:
        self.__optimizer.load_state_dict(checkpoint['optimizer'])

    num_workers = torch.cuda.device_count() * 4
    if choose_cell:
        # Architecture-selection mode: carve the training set into a
        # train/validation split by train_portion.
        num_train = len(train_data)  # 50000
        indices = list(range(num_train))
        split = int(np.floor(args.train_portion * num_train))  # 25000
        self.__train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batchsz,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[:split]),
            pin_memory=True,
            num_workers=num_workers)
        self.__valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batchsz,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:]),
            pin_memory=True,
            num_workers=num_workers)
    else:
        self.__train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batchsz,
            shuffle=True,
            pin_memory=True,
            num_workers=num_workers)
        self.__valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=args.batchsz,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers)

    # BUGFIX: the original condition was
    #   `if my_dataset == MyDataset.CIFAR10 or MyDataset.CIFAR100:`
    # which is always true (a bare enum member is truthy), so ImageNet could
    # never reach its StepLR branch. Use a membership test instead.
    if my_dataset in (MyDataset.CIFAR10, MyDataset.CIFAR100):
        # CIFAR: cosine annealing over the (remaining) epoch budget.
        self.__scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.__optimizer, args.epochs)
    elif my_dataset == MyDataset.ImageNet:
        # ImageNet: step decay every `decay_period` epochs.
        self.__scheduler = torch.optim.lr_scheduler.StepLR(
            self.__optimizer, args.decay_period, gamma=args.gamma)
    else:
        raise Exception('No match Dataset')
    if checkpoint:
        self.__scheduler.load_state_dict(checkpoint['scheduler'])