def get_destructive_overthinking_samples(models_path, device='cpu'):
    """Copy test images that only the first SDN output classifies correctly.

    Loads the ``tinyimagenet_vgg16bn_sdn_ic_only`` model, gathers per-output
    results on the test loader, and copies every image that the first output
    gets right while no later output does into the ``only_first`` directory.
    Each copy is named ``<first prediction>_<last prediction>_<file name>``,
    where predictions are translated through ``dataset.testset_paths.classes``.

    Args:
        models_path: Location passed to ``arcs.load_model``.
        device: Device the model is moved to and evaluated on.
    """
    sdn_name = 'tinyimagenet_vgg16bn_sdn_ic_only'
    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'])

    output_path = 'only_first'
    af.create_path(output_path)

    layer_correct, _, layer_predictions, _ = mf.sdn_get_detailed_results(
        sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(layer_correct)
    first_layer, last_layer = layers[0], layers[-1]

    # Everything that at least one output after the first classifies correctly.
    later_correct = set()
    for layer in layers[1:]:
        later_correct |= layer_correct[layer]

    only_first = layer_correct[first_layer] - later_correct
    class_names = dataset.testset_paths.classes
    for instance_id in only_first:
        instance_path = dataset.testset_paths.imgs[instance_id][0]
        print(instance_path)

        # Index by the actual first key instead of a hard-coded 0 so the lookup
        # stays consistent with how ``layer_correct`` is accessed above.
        first_predict = class_names[layer_predictions[first_layer][instance_id][0]]
        last_predict = class_names[layer_predictions[last_layer][instance_id][0]]

        filename = '{}_{}_{}'.format(first_predict, last_predict,
                                     os.path.basename(instance_path))
        copyfile(instance_path, os.path.join(output_path, filename))
def get_simple_complex(models_path, device='cpu'):
    """Sort dog and cat test images by the first SDN output that gets them right.

    For every output except the last one, the images that this output is the
    first to classify correctly are collected. Those whose path contains the
    dog synset id (``n02099601``) or the cat synset id (``n02123394``) are
    copied into ``simple_complex_images/dog`` or ``simple_complex_images/cat``,
    with the output id prefixed to the file name.

    Args:
        models_path: Location passed to ``arcs.load_model``.
        device: Device the model is moved to and evaluated on.
    """
    model_name = 'tinyimagenet_vgg16bn_sdn_ic_only'
    sdn_model, sdn_params = arcs.load_model(models_path, model_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'])

    output_path = 'simple_complex_images'
    af.create_path(output_path)
    dog_path = '{}/{}'.format(output_path, 'dog')
    cat_path = '{}/{}'.format(output_path, 'cat')
    af.create_path(dog_path)
    af.create_path(cat_path)

    # Synset id -> destination directory (n02099601 dog 26, n02123394 cat 31).
    destinations = (('n02099601', dog_path), ('n02123394', cat_path))

    layer_correct, layer_wrong, _, _ = mf.sdn_get_detailed_results(
        sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(list(layer_correct.keys()))

    # Starts as every instance id; shrinks to the ones no output so far got right.
    wrong_until = layer_wrong[layers[0]] | layer_correct[layers[0]]

    for layer in layers[:-1]:
        correct_here = layer_correct[layer]
        instances = correct_here & wrong_until
        wrong_until = wrong_until - correct_here
        print('IC: {}, Num images: {}'.format(layer, len(instances)))

        for instance_id in instances:
            instance_path = dataset.testset_paths.imgs[instance_id][0]
            filename = '{}_{}'.format(layer, os.path.basename(instance_path))
            for synset_id, target_dir in destinations:
                if synset_id in instance_path:
                    copyfile(instance_path, target_dir + '/' + filename)
def train_model(models_path, cr_params, device, num=0):
    """Create, train and save an iteratively grown / pruned ResNet on CIFAR-10.

    Args:
        models_path: Location passed to ``arcs.create_resnet_iterative`` and
            ``arcs.save_model``.
        cr_params: 4-tuple ``(type, mode, pruning, ics)`` forwarded to
            ``arcs.create_resnet_iterative``. ``mode`` ("0" or "1") selects the
            epoch / milestone / gamma schedule; ``pruning[2]`` is used as the
            pruning batch size; a ``type`` containing "full" forces the
            learning rate to 0.1.
        device: Device the model is trained on.
        num: Not used by this function.

    Returns:
        ``(best_model, params)`` where ``best_model`` is the second value
        returned by ``model.train_func`` and ``params`` is the updated
        parameter dict that was saved alongside it.
    """
    # Local renamed from ``type`` so the builtin is not shadowed.
    arch_type, mode, pruning, ics = cr_params
    model, params = arcs.create_resnet_iterative(models_path, arch_type, mode,
                                                 pruning, ics, False)
    dataset = af.get_dataset('cifar10')

    params['name'] = params['base_model'] + '_{}_{}'.format(arch_type, mode)
    if model.prune:
        params['name'] += "_prune_{}".format(
            [x * 100 for x in model.keep_ratio])
        print("prune: {}".format(model.keep_ratio))

    # Training schedule per mode; any other mode keeps the created defaults.
    if mode == "0":
        params['epochs'] = 250
        params['milestones'] = [120, 160, 180]
        params['gammas'] = [0.1, 0.01, 0.01]
    elif mode == "1":
        params['epochs'] = 300
        params['milestones'] = [100, 150, 200]
        params['gammas'] = [0.1, 0.1, 0.1]

    if "full" in arch_type:
        params['learning_rate'] = 0.1
    print("lr: {}".format(params['learning_rate']))

    opti_param = (params['learning_rate'], params['weight_decay'],
                  params['momentum'], -1)
    lr_schedule_params = (params['milestones'], params['gammas'])

    model.to(device)
    train_params = dict(
        epochs=params['epochs'],
        epoch_growth=[25, 50, 75],
        epoch_prune=[10, 35, 60, 85, 110, 135, 160],  # previously [10, 35, 60, 85]
        prune_batch_size=pruning[2],
        prune_type='2',  # 0 skip layer, 1 normal full, 2 iterative
        reinit=False,
        # not needed if skip layers, minimum for the iterative pruning
        min_ratio=[0.3, 0.1, 0.05, 0.05],
    )
    params['epoch_growth'] = train_params['epoch_growth']
    params['epoch_prune'] = train_params['epoch_prune']

    optimizer, scheduler = af.get_full_optimizer(model, opti_param,
                                                 lr_schedule_params)
    metrics, best_model = model.train_func(model, dataset, train_params,
                                           optimizer, scheduler, device)
    # Presumably copies the metric entries read below into ``params`` —
    # TODO confirm against ``_link_metrics``.
    _link_metrics(params, metrics)

    af.print_sparsity(best_model)
    arcs.save_model(best_model, params, models_path, params['name'], epoch=-1)
    print("test acc: {}, last val: {}".format(params['test_top1_acc'],
                                              params['valid_top1_acc'][-1]))
    return best_model, params
def destructive_overthinking_experiment(models_path, device='cpu'):
    """Print per-output destructive-overthinking statistics for an SDN.

    For each output, in sorted order, this prints the number of correctly
    classified instances, the cumulative number of instances any output so far
    got right, and how many of the currently correct instances the final output
    gets wrong. It also prints smoothed average confidences (the denominator is
    ``0.1 + count``) over those overthinking instances and over all currently
    correct instances.

    Args:
        models_path: Location passed to ``arcs.load_model``.
        device: Device the model is moved to and evaluated on.
    """
    # Alternatives left by the original author: tasks 'cifar10' / 'cifar100',
    # networks 'resnet56' / 'wideresnet32_4' / 'mobilenet'. The backdoored
    # 'cifar10_vgg16bn_bd_sdn_converted' model was used with add_trigger = True.
    add_trigger = False
    task = 'tinyimagenet'
    network = 'vgg16bn'
    sdn_name = task + '_' + network + '_sdn_ic_only'

    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)

    top1_test, top5_test = mf.sdn_test(sdn_model, dataset.test_loader, device)
    print('Top1 Test accuracy: {}'.format(top1_test))
    print('Top5 Test accuracy: {}'.format(top5_test))

    layer_correct, layer_wrong, _, layer_confidence = mf.sdn_get_detailed_results(
        sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(list(layer_correct.keys()))

    end_wrong = layer_wrong[layers[-1]]
    cum_correct = set()

    for layer in layers:
        cur_correct = layer_correct[layer]
        cum_correct = cum_correct | cur_correct
        cur_overthinking = cur_correct & end_wrong
        confidences = layer_confidence[layer]

        print('Output: {}'.format(layer))
        print('Current correct: {}'.format(len(cur_correct)))
        print('Cumulative correct: {}'.format(len(cum_correct)))
        print('Cur cat. overthinking: {}\n'.format(len(cur_overthinking)))

        overthinking_confidence = sum(
            (confidences[instance] for instance in cur_overthinking), 0.0)
        print('Average confidence on destructive overthinking instances:{}'.format(
            overthinking_confidence / (0.1 + len(cur_overthinking))))

        correct_confidence = sum(
            (confidences[instance] for instance in cur_correct), 0.0)
        print('Average confidence on correctly classified :{}'.format(
            correct_confidence / (0.1 + len(cur_correct))))
def wasteful_overthinking_experiment(models_path, device='cpu'):
    """Print per-output correctness statistics and a computation-cost total.

    For each output, in sorted order, this prints the number of instances it
    classifies correctly, the cumulative number any output so far got right,
    and the number this output is the first to get right. ``total_comp``
    accumulates ``len(unique_correct) * c_i[layer]`` for every output before
    the last, and for the last output the number of instances no earlier
    output classified correctly.

    Args:
        models_path: Location passed to ``arcs.load_model``.
        device: Device the model is moved to and evaluated on.
    """
    # Alternatives left by the original author: tasks 'cifar10' / 'cifar100',
    # networks 'resnet56' / 'wideresnet32_4' / 'mobilenet'.
    task = 'tinyimagenet'
    network = 'vgg16bn'
    sdn_name = task + '_' + network + '_sdn_ic_only'

    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'])

    top1_test, top5_test = mf.sdn_test(sdn_model, dataset.test_loader, device)
    print('Top1 Test accuracy: {}'.format(top1_test))
    print('Top5 Test accuracy: {}'.format(top5_test))

    layer_correct, layer_wrong, _, _ = mf.sdn_get_detailed_results(
        sdn_model, loader=dataset.test_loader, device=device)
    layers = sorted(layer_correct)

    # Number of evaluated instances, derived from the results instead of the
    # previously hard-coded 10000. Assumes every instance lands in exactly one
    # of the two sets for an output — TODO confirm against
    # ``mf.sdn_get_detailed_results``.
    total = len(layer_correct[layers[0]] | layer_wrong[layers[0]])

    # to quantify the computational waste
    # NOTE(review): indexed by output id, so ids before the last must be 0..5.
    c_i = [0.15, 0.3, 0.45, 0.6, 0.75, 0.9]
    total_comp = 0

    cum_correct = set()
    for layer in layers:
        cur_correct = layer_correct[layer]
        unique_correct = cur_correct - cum_correct
        cum_correct = cum_correct | cur_correct

        print('Output: {}'.format(layer))
        print('Current correct: {}'.format(len(cur_correct)))
        print('Cumulative correct: {}'.format(len(cum_correct)))
        print('Unique correct: {}\n'.format(len(unique_correct)))

        if layer < layers[-1]:
            total_comp += len(unique_correct) * c_i[layer]
        else:
            # Everything not correct at an earlier output counts at full cost.
            total_comp += total - (len(cum_correct) - len(unique_correct))

    print('Total Comp: {}'.format(total_comp))
def train_model(models_path, device):
    """Create a CIFAR-10 ResNet-56 SDN, train it, and save the result.

    The model is created with ``arcs.create_resnet56`` and loaded back at
    epoch 0. It is trained with the hyper-parameters stored in its parameter
    dict, and the collected metrics plus total training time are written into
    that dict. The model is then saved under ``<sdn name>_training`` with
    ``epoch=-1``.

    Args:
        models_path: Location used to create, load and save the model.
        device: Device the model is trained on.

    Returns:
        ``(trained_model, dataset)``.
    """
    _cnn, sdn = arcs.create_resnet56(models_path, 'cifar10', save_type='d')
    print('snd name: {}'.format(sdn))
    # train_sdn(models_path, sdn, device)

    print("Training model...")
    trained_model, model_params = arcs.load_model(models_path, sdn, 0)
    dataset = af.get_dataset(model_params['task'])

    # Hyper-parameters, read in the same order as before.
    lr = model_params['learning_rate']
    momentum = model_params['momentum']
    weight_decay = model_params['weight_decay']
    lr_schedule_params = (model_params['milestones'], model_params['gammas'])
    num_epochs = model_params['epochs']
    model_params['optimizer'] = 'SGD'

    optimizer, scheduler = af.get_full_optimizer(
        trained_model, (lr, weight_decay, momentum, -1), lr_schedule_params)

    trained_model_name = sdn + '_training'
    print('Training: {}...'.format(trained_model_name))
    trained_model.to(device)
    metrics = trained_model.train_func(trained_model, dataset, num_epochs,
                                       optimizer, scheduler, device=device)

    # Copy the reported metrics into the parameters that get saved.
    for key in ('train_top1_acc', 'test_top1_acc', 'train_top3_acc',
                'test_top3_acc', 'epoch_times', 'lrs'):
        model_params[key] = metrics[key]

    total_training_time = sum(model_params['epoch_times'])
    model_params['total_time'] = total_training_time
    print('Training took {} seconds...'.format(total_training_time))

    arcs.save_model(trained_model, model_params, models_path,
                    trained_model_name, epoch=-1)
    return trained_model, dataset
def check_performance(trained_models_path, device):
    """Run the SDN and its per-exit policy networks over the test set.

    Loads ``tinyimagenet_vgg16bn_sdn_sdn_training_ds`` together with one
    ``PolicyNet`` per internal classifier (weights read from
    ``./policy_<k>.dump``, k starting at 1). For every test batch it collects
    the stacked SDN outputs and the concatenated policy outputs.

    Args:
        trained_models_path: Location passed to ``arcs.load_model``.
        device: Device used for the models and the data.

    NOTE(review): the aggregated ``predictions`` and ``stops`` tensors are
    computed but neither returned nor printed in the visible code; the return
    value is left as ``None`` to keep the existing interface.
    """
    add_trigger = False
    task = 'tinyimagenet'
    network = 'vgg16bn'
    sdn_name = task + '_' + network + '_sdn_sdn_training_ds'

    sdn_model, sdn_params = arcs.load_model(trained_models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    sdn_model.eval()  # inference only
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)

    num_policies = sum(sdn_params['add_ic'])
    policy_net_all = []
    for i in range(num_policies):
        policy_net = PolicyNet('tiny_imagenet', 200).to(device)
        # map_location lets dumps saved on another device load onto ``device``.
        state_dict = torch.load('./policy_%d.dump' % (i + 1), map_location=device)
        policy_net.load_state_dict(state_dict)
        policy_net.eval()
        policy_net_all.append(policy_net)

    predictions = []
    stops = []
    for val_x, _ in dataset.test_loader:
        val_x = val_x.to(device)
        with torch.no_grad():
            xhs = sdn_model(val_x)
            predictions.append(torch.stack(xhs))
            policy_pred = [
                policy_net_all[t](val_x, xhs[t]) for t in range(num_policies)
            ]
            stops.append(torch.cat(policy_pred, dim=-1))

    # Policy outputs are joined along the batch axis; the stacked SDN outputs
    # carry the output index first, so their batch axis is dim 1.
    stops = torch.cat(stops, dim=0)
    predictions = torch.cat(predictions, dim=1)
def early_exit_experiments(models_path, device='cpu'):
    """Compare a CNN against its SDN under confidence-based early exits.

    For every configured SDN this loads the SDN and its CNN counterpart,
    prints the CNN's cost and accuracy, prints the SDN's per-output cost and
    accuracy, and then sweeps a list of confidence thresholds. For each
    threshold the SDN's ``forward`` is replaced by ``early_exit`` and the exit
    counts, accuracies, run time and average operation count per instance are
    printed.

    Args:
        models_path: Location passed to ``arcs.load_model``.
        device: Device the models are moved to and evaluated on.
    """
    sdn_training_type = 'ic_only'  # IC-only training; alternative: 'sdn_training'
    task = 'tinyimagenet'  # alternatives: 'cifar10', 'cifar100'
    # Full list: ['vgg16bn_sdn', 'resnet56_sdn', 'wideresnet32_4_sdn', 'mobilenet_sdn']
    architectures = ['vgg16bn_sdn']

    training_suffix = '_' + sdn_training_type
    sdn_names = [task + '_' + arch + training_suffix for arch in architectures]

    for sdn_name in sdn_names:
        # Drop the training-type suffix *before* swapping 'sdn' for 'cnn'.
        # Replacing first turned '_sdn_sdn_training' into '_cnn_cnn_training',
        # which the later suffix removal could no longer match.
        cnn_name = sdn_name[:-len(training_suffix)].replace('sdn', 'cnn')

        print(sdn_name)
        print(cnn_name)

        sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
        sdn_model.to(device)
        dataset = af.get_dataset(sdn_params['task'])

        cnn_model, _ = arcs.load_model(models_path, cnn_name, epoch=-1)
        cnn_model.to(device)

        print('Get CNN results')
        top1_test, top5_test, total_time = mf.cnn_test_time(
            cnn_model, dataset.test_loader, device)
        total_ops, total_params = profile(cnn_model, cnn_model.input_size, device)
        print("#Ops: %f GOps" % (total_ops / 1e9))
        print("#Parameters: %f M" % (total_params / 1e6))
        print('Top1 Test accuracy: {}'.format(top1_test))

        print('25 percent cost: {}'.format((total_ops / 1e9) * 0.25))
        print('50 percent cost: {}'.format((total_ops / 1e9) * 0.5))
        print('75 percent cost: {}'.format((total_ops / 1e9) * 0.75))

        # to test early-exits with the SDN
        one_batch_dataset = af.get_dataset(sdn_params['task'], 1)
        print('Get SDN early exit results')
        # From here on ``total_ops`` is indexed by output id.
        total_ops, total_params = profile_sdn(sdn_model, sdn_model.input_size, device)
        print("#Ops (GOps): {}".format(total_ops))
        print("#Params (mil): {}".format(total_params))
        top1_test, top5_test = mf.sdn_test(sdn_model, dataset.test_loader, device)
        print('Top1 Test accuracy: {}'.format(top1_test))

        print('Calibrate confidence_thresholds')
        # search for the confidence threshold for early exits
        confidence_thresholds = [0.1, 0.15, 0.25, 0.5, 0.6, 0.7,
                                 0.8, 0.9, 0.95, 0.99, 0.999]
        # change the forward func for sdn to forward with cascade
        sdn_model.forward = sdn_model.early_exit

        for threshold in confidence_thresholds:
            print(threshold)
            sdn_model.confidence_threshold = threshold

            (top1_test, top5_test, early_exit_counts, non_conf_exit_counts,
             total_time) = mf.sdn_test_early_exits(
                 sdn_model, one_batch_dataset.test_loader, device)

            # Non-confident exits are charged the cost of the last output id.
            # The original relied on the loop variable leaking out of the
            # first loop; this names the same index explicitly.
            last_output_id = len(early_exit_counts) - 1

            average_mult_ops = 0
            total_num_instances = 0
            for output_id, output_count in enumerate(early_exit_counts):
                average_mult_ops += output_count * total_ops[output_id]
                total_num_instances += output_count

            for output_count in non_conf_exit_counts:
                total_num_instances += output_count
                average_mult_ops += output_count * total_ops[last_output_id]

            average_mult_ops /= total_num_instances

            print('Early exit Counts:')
            print(early_exit_counts)

            print('Non confident exit counts:')
            print(non_conf_exit_counts)

            print('Top1 Test accuracy: {}'.format(top1_test))
            print('Top5 Test accuracy: {}'.format(top5_test))
            print('SDN cascading took {} seconds.'.format(total_time))
            print('Average Mult-Ops: {}'.format(average_mult_ops))
def train(models_path, untrained_models, sdn=False, ic_only_sdn=False, device='cpu', ds=False):
    """Train every model in ``untrained_models`` and save the trained copies.

    Each model is loaded at epoch 0 with its parameter dict and trained through
    its own ``train_func``. It is saved with ``epoch=-1`` under a name derived
    from the base name: ``_ic_only_ic<N>`` or ``_sdn_training_ic<N>`` is
    appended for SDNs (``N`` being the sum of ``add_ic``), followed by ``_ds``
    when ``ds`` is set.

    Args:
        models_path: Location used to load and save the models.
        untrained_models: Iterable of base model names.
        sdn: Whether the models are SDNs (selects optimizer and name suffix).
        ic_only_sdn: Use the ``ic_only`` hyper-parameters and, for SDNs, the
            IC-only optimizer.
        device: Device the models are trained on.
        ds: Value stored on the model's ``ds`` attribute (as a bool); also
            appends ``_ds`` to the saved name.
    """
    print('Training models...')

    for base_model in untrained_models:
        trained_model, model_params = arcs.load_model(models_path, base_model, 0)
        dataset = af.get_dataset(model_params['task'])

        learning_rate = model_params['learning_rate']
        momentum = model_params['momentum']
        weight_decay = model_params['weight_decay']
        milestones = model_params['milestones']
        gammas = model_params['gammas']
        num_epochs = model_params['epochs']
        model_params['optimizer'] = 'SGD'

        if ic_only_sdn:
            # IC-only training, freeze the original weights
            ic_only_params = model_params['ic_only']
            learning_rate = ic_only_params['learning_rate']
            num_epochs = ic_only_params['epochs']
            milestones = ic_only_params['milestones']
            gammas = ic_only_params['gammas']
            model_params['optimizer'] = 'Adam'

        trained_model.ic_only = True if ic_only_sdn else False
        trained_model.ds = True if ds else False

        optimization_params = (learning_rate, weight_decay, momentum)
        lr_schedule_params = (milestones, gammas)

        if sdn and ic_only_sdn:
            optimizer, scheduler = af.get_sdn_ic_only_optimizer(
                trained_model, optimization_params, lr_schedule_params)
            trained_model_name = '{}_ic_only_ic{}'.format(
                base_model, np.sum(model_params['add_ic']))
        elif sdn:
            optimizer, scheduler = af.get_full_optimizer(
                trained_model, optimization_params, lr_schedule_params)
            trained_model_name = '{}_sdn_training_ic{}'.format(
                base_model, np.sum(model_params['add_ic']))
        else:
            optimizer, scheduler = af.get_full_optimizer(
                trained_model, optimization_params, lr_schedule_params)
            trained_model_name = base_model

        if ds:
            trained_model_name += '_ds'

        print('Training: {}...'.format(trained_model_name))
        # trained_model = nn.DataParallel(trained_model)
        trained_model.to(device)
        metrics = trained_model.train_func(trained_model, dataset, num_epochs,
                                           optimizer, scheduler, device=device)

        # Persist the reported metrics next to the hyper-parameters.
        for key in ('train_top1_acc', 'test_top1_acc', 'train_top5_acc',
                    'test_top5_acc', 'epoch_times', 'lrs'):
            model_params[key] = metrics[key]

        total_training_time = sum(model_params['epoch_times'])
        model_params['total_time'] = total_training_time
        print('Training took {} seconds...'.format(total_training_time))

        arcs.save_model(trained_model, model_params, models_path,
                        trained_model_name, epoch=-1)
def iter_training_0(model, data, params, optimizer, scheduler, device='cpu'):
    """Train ``model`` while growing it and, optionally, pruning it on schedule.

    Args:
        model: Network exposing ``augment_training``, ``num_ics``,
            ``num_output``, ``prune``, ``keep_ratio``, ``grow()`` and
            ``to_train()``.
        data: Dataset object; its ``test_loader`` is used for the final tests
            and the object itself is passed to ``epoch_routine``.
        params: Dict with keys ``epochs``, ``epoch_growth``, ``epoch_prune``,
            ``prune_batch_size``, ``prune_type``, ``reinit`` and ``min_ratio``.
        optimizer: Optimizer; parameters returned by ``model.grow()`` are added
            to it as a new param group.
        scheduler: Learning-rate scheduler forwarded to ``epoch_routine``.
        device: Device used for training and testing.

    Returns:
        ``(metrics, best_model)``.

    NOTE(review): this block was recovered from whitespace-collapsed source, so
    the nesting shown here (in particular the placement of
    ``af.print_sparsity(model)`` after the epoch loop) is a reconstruction —
    confirm against the original file.
    """
    print("iter training 0")
    augment = model.augment_training
    # Per-epoch histories; the test entries are overwritten after training.
    metrics = {
        'epoch_times': [],
        'valid_top1_acc': [],
        'valid_top3_acc': [],
        'train_top1_acc': [],
        'train_top3_acc': [],
        'test_top1_acc': [],
        'test_top3_acc': [],
        'lrs': []
    }
    epochs, epoch_growth, epoch_prune = params['epochs'], params[
        'epoch_growth'], params['epoch_prune']
    pruning_batch_size, pruning_type, reinit = params[
        'prune_batch_size'], params['prune_type'], params['reinit']
    #epoch_growth = [25, 50, 75] # [(i + 1) * epochs / (model.num_ics + 1) for i in range(model.num_ics)]
    print("array params: num_ics {}, epochs {}".format(model.num_ics, epochs))
    print("epochs growth: {}".format(epoch_growth))
    print("epochs prune: {}".format(epoch_prune))
    # Only printed here; not used further in this function.
    max_coeffs = calc_coeff(model)
    print('max_coeffs: {}'.format(max_coeffs))
    model.to(device)
    model.to_train()
    if model.prune:
        # Separate CIFAR-10 dataset instance with the pruning batch size.
        prune_dataset = af.get_dataset('cifar10',
                                       batch_size=pruning_batch_size)
        print("pruning_batch_size: {}, prune_type: {}, reinit: {}".format(
            pruning_batch_size, pruning_type, reinit))
        print("min_ratio: {}".format(params['min_ratio']))
        print("keep_ratio: {}".format(model.keep_ratio))
    best_model, accuracies, best_epoch = None, None, 0
    masks = []  # one entry per iterative pruning step
    mask1 = None  # most recent mask; fed back into prune_skip_layer
    block_to_prune = 0  # advances once per skip-layer pruning step
    for epoch in range(1, epochs + 1):
        print('\nEpoch: {}/{}'.format(epoch, epochs))
        if epoch in epoch_growth:
            # Grow the network and let the optimizer train the new parameters.
            grown_layers = model.grow()
            model.to(device)
            optimizer.add_param_group({'params': grown_layers})
            print("model grow")
        if epoch in epoch_prune and model.prune:
            loader = get_loader(prune_dataset, False)
            if pruning_type == '0':
                mask1 = prune_skip_layer(model, model.keep_ratio, loader,
                                         sdn_loss, block_to_prune, mask1,
                                         device, reinit)
                block_to_prune += 1
            elif pruning_type == '1':
                prune2(model, model.keep_ratio, loader, sdn_loss, device)
            elif pruning_type == "2":
                # One step value per growth stage (stage 0 starts at epoch 1).
                # A stage not reached yet gets 0; otherwise it gets the number
                # of pruning epochs between the stage's start and the current
                # epoch (both inclusive), minus one.
                steps = []
                _epoch_growth = [1] + epoch_growth
                for i in range(len(_epoch_growth)):
                    if epoch < _epoch_growth[i]:
                        steps.append(0)
                    else:
                        v = [
                            0 if e < _epoch_growth[i] or e > epoch else 1
                            for e in epoch_prune
                        ]
                        steps.append(sum(v) - 1)
                mask = prune_iterative(model, model.keep_ratio,
                                       params['min_ratio'], steps, loader,
                                       sdn_loss, device, reinit)
                masks.append(mask)
                mask1 = mask
        epoch_routine(model, data, optimizer, scheduler, epoch, epochs,
                      augment, metrics, device)
        # Best-model tracking only runs once the model reports
        # num_ics + 1 outputs.
        if model.num_output == model.num_ics + 1:
            if model.prune and epoch >= epoch_prune[-1]:
                # With pruning on, only epochs from the last pruning epoch
                # onward are considered.
                print("pruning for best_model")
                best_model, accuracies, best_epoch = best_model_def(
                    best_model, model, accuracies, best_epoch, metrics, epoch)
            elif not model.prune:
                best_model, accuracies, best_epoch = best_model_def(
                    best_model, model, accuracies, best_epoch, metrics, epoch)
    af.print_sparsity(model)
    # NOTE(review): best_model is still None here if the tracking branch above
    # never ran; confirm sdn_test tolerates that or that it cannot happen.
    metrics['test_top1_acc'], metrics['test_top3_acc'] = sdn_test(
        best_model, data.test_loader, device)
    test_top1, _ = sdn_test(model, data.test_loader, device)
    metrics['best_model_epoch'] = best_epoch
    metrics['masks'] = masks
    print("best epoch: {}".format(best_epoch))
    print("comparison best and latest: {}/{}".format(metrics['test_top1_acc'],
                                                     test_top1))
    return metrics, best_model
def policy_training(models_path, device='cpu'):
    """Fit and evaluate an exit-policy ("score") network on top of a trained SDN.

    Steps, in order:

    1. Load the SDN and its dataset, and run the SDN over the validation
       loader to collect every output's logits, the labels, and the posterior
       returned by ``PolicyKL.true_posterior``.
    2. Average the posterior over the validation set and keep the output
       positions whose mean exceeds 0.001 (``nz_post``); ``train_post`` is the
       same mapping without its last entry.
    3. Print the accuracy of every output and of confidence-threshold early
       exiting for a list of thresholds. A threshold of ``-1`` means "use the
       most confident output".
    4. Build the score network selected by ``cmd_args.model_type``, train it
       with ``PolicyKL`` when ``cmd_args.phase == 'train'``, then load the best
       validation checkpoint from ``cmd_args.save_dir`` and run
       ``PolicyKL.test`` on the test loader.

    Args:
        models_path: Location passed to ``arcs.load_model``.
        device: Device used for the models and the data.

    Raises:
        ValueError: If ``cmd_args.model_type`` is not one of the supported
            values.
    """
    # Alternatives left by the original author: tasks 'cifar10' / 'cifar100',
    # networks 'resnet56' / 'wideresnet32_4' / 'mobilenet'. The backdoored
    # 'cifar10_vgg16bn_bd_sdn_converted' model was used with add_trigger = True.
    add_trigger = False
    task = 'tinyimagenet'
    network = 'vgg16bn'
    # Other name suffixes used before: '_sdn_ic_only', '_sdn_ic_only_ic1_ds',
    # '_sdn_sdn_training', '_sdn_sdn_training_ds', '_sdn_sdn_training_ic14_ds'.
    sdn_name = task + '_' + network + '_sdn_ic_only_ic1'

    sdn_model, sdn_params = arcs.load_model(models_path, sdn_name, epoch=-1)
    sdn_model.to(device)
    dataset = af.get_dataset(sdn_params['task'], add_trigger=add_trigger)

    # need to construct the policy network and train the policy net.
    # the architecture of the policy network need to be designed.
    sdn_model.eval()
    p_true_all = list()
    xhs_all = list()
    y_all = list()
    for batch in dataset.val_loader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            xhs = sdn_model(x)
        # Size of the last dimension of the final output, i.e. class count.
        categories = xhs[-1].shape[-1]
        p_true, _ = PolicyKL.true_posterior(cmd_args, xhs, y)
        xhs_all.append(xhs)
        y_all.append(y)
        p_true_all.append(p_true)

    p_true = torch.cat(p_true_all, dim=0)
    p_det = max_onehot(p_true, dim=-1, device=device)
    p_true = torch.mean(p_true, dim=0)

    # find positions with nonzero posterior
    train_post = {}
    nz_post = {}
    i = 0
    for t in range(cmd_args.num_output):
        if p_true[t] > 0.001:
            train_post[i] = t
            nz_post[i] = t
            i += 1
    # The last kept position is dropped from the trainable set.
    # NOTE(review): raises KeyError when no position passes the 0.001 cut.
    del train_post[i - 1]

    p_str = 'val p true:['
    p_str += ','.join(['%0.3f' % p_true[t] for t in nz_post.values()])
    print(p_str + ']')

    p_det = torch.mean(p_det, dim=0)
    p_str = 'val p true det:['
    p_str += ','.join(['%0.3f' % p_det[t] for t in nz_post.values()])
    print(p_str + ']')

    ####
    # check the performance based on confidence score
    ####
    y_all = torch.cat(y_all, dim=-1)
    # Regroup from per-batch lists of outputs to per-output lists of batches.
    xhs_all = list(zip(*xhs_all))
    for i in range(len(xhs_all)):
        xhs_all[i] = torch.cat(xhs_all[i], dim=0)
        print('The {}th classifier performance:'.format(i))
        prec1, prec5 = data.accuracy(xhs_all[i], y_all, topk=(1, 5))
        print('Top1 Test accuracy: {}'.format(prec1))
        print('Top5 Test accuracy: {}'.format(prec5))

    xhs_all = [F.softmax(logits, dim=-1) for logits in xhs_all]
    max_confidences = torch.stack(
        [torch.max(probs, dim=-1)[0] for probs in xhs_all], dim=-1)
    xhs_all_stack = torch.stack(xhs_all, dim=1)  # (instances, outputs, classes)
    predictions = torch.stack(
        [torch.argmax(probs, dim=-1) for probs in xhs_all], dim=-1)

    thresholds = [
        0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999,
        -1
    ]
    for threshold in thresholds:
        if threshold == -1:
            index = torch.argmax(max_confidences, dim=-1).cpu().numpy()
        else:
            mask = (max_confidences > threshold).to(int).cpu().numpy()
            mask[:, -1] = 1  # the final output always accepts
            # argmax returns the first maximum, i.e. the first output whose
            # confidence clears the threshold.
            index = mask.argmax(axis=1)

        # Select, per instance, the class distribution of the chosen output.
        # The index is widened to the real class count rather than the
        # previously hard-coded 200 so other tasks work too.
        gather_index = torch.as_tensor(index, dtype=torch.long, device=device)
        gather_index = gather_index.view(-1, 1, 1).expand(-1, 1, categories)
        results = xhs_all_stack.gather(1, gather_index).squeeze()

        prec1, prec5 = data.accuracy(results, y_all, topk=(1, 5))
        print('htreshold: ', threshold)
        print('Top1 Test accuracy: {}'.format(prec1))
        print('Top5 Test accuracy: {}'.format(prec5))
    ####
    # confidence score check finish
    ####

    # Placeholder feature maps handed to the score-net constructors.
    internal_fm = [torch.rand(2, 2) for i in range(cmd_args.num_output)]

    # initialize nets with nonzero posterior; ``x`` is the last validation batch
    if cmd_args.model_type in ('sequential', 'confidence'):
        score_net = MNIconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
    elif cmd_args.model_type == 'multiclass':
        score_net = MulticlassNetImage(cmd_args,
                                       x,
                                       internal_fm,
                                       train_post,
                                       category=categories)
    elif cmd_args.model_type == 'imiconfidence':
        score_net = Imiconfidence(cmd_args,
                                  x,
                                  internal_fm,
                                  train_post,
                                  category=categories,
                                  share=cmd_args.share,
                                  net_size=cmd_args.net_size)
    else:
        # Previously an unknown type surfaced later as an unbound variable.
        raise ValueError('unsupported model_type: {!r}'.format(
            cmd_args.model_type))
    score_net.to(device)

    # train
    if cmd_args.phase == 'train':
        optimizer = optim.Adam(list(score_net.parameters()),
                               lr=cmd_args.learning_rate,
                               weight_decay=cmd_args.weight_decay)
        milestones = [10, 20, 40, 60, 80]
        gammas = [0.4, 0.2, 0.2, 0.2, 0.2]
        scheduler = MultiStepMultiLR(optimizer,
                                     milestones=milestones,
                                     gammas=gammas)
        trainer = PolicyKL(args=cmd_args,
                           sdn_model=sdn_model,
                           score_net=score_net,
                           train_post=train_post,
                           nz_post=nz_post,
                           optimizer=optimizer,
                           data_loader=dataset,
                           device=device,
                           scheduler=scheduler,
                           sdn_name=sdn_name)
        trainer.train()

    # test
    dump = cmd_args.save_dir + '/{}_best_val_policy.dump'.format(sdn_name)
    print('Loading model...')
    # map_location lets a checkpoint saved on another device load onto ``device``.
    score_net.load_state_dict(torch.load(dump, map_location=device))
    PolicyKL.test(args=cmd_args,
                  score_net=score_net,
                  sdn_model=sdn_model,
                  data_loader=dataset.test_loader,
                  nz_post=nz_post,
                  device=device)
    print(cmd_args.save_dir)