class YellowFin(UpdadteFunction):
    """
    This class is a wrapper for the YellowFin implementation found at:
    https://github.com/AnonRepository/YellowFin_Pytorch
    and described in this paper: https://arxiv.org/pdf/1706.03471.pdf

    Because the code is not distributed under any licence, we cannot include
    it in the project. To make this project work, download the tuner_utils
    folder from the github page and add it to the project folder. It also has
    to be patched to work with newer versions of pytorch (add .numpy to line
    275 and 283).
    """

    def __init__(self, flags, agents, gpu_nr, memory_fraction, weights, lock,
                 scale, staleness_aware, epoch):
        super().__init__(flags, agents, gpu_nr, memory_fraction, weights, lock)
        from tuner_utils.yellowfin import YFOptimizer

        self.times = []
        self.flags = flags
        self.agents = agents
        self.gpu_nr = gpu_nr
        self.lock = lock
        self.scale = scale
        self.staleness_aware = staleness_aware
        self.nesterov = True
        self.lr_dict = iter([(1, 0.4), (82, 0.04), (123, 0.004), (165, 0.0001)])
        self.next_switch = next(self.lr_dict)
        self.learning_rate = 0
        self.batchsize = self.flags.batch_size
        self.epochs = epoch

        if self.flags.drop_remainder:
            self.iterations_in_epoch = self.agents * math.floor(
                self.flags.train_set_size / self.flags.batch_size)
        else:
            self.iterations_in_epoch = self.agents * math.ceil(
                self.flags.train_set_size / self.flags.batch_size)
        logger.state("iterations in epoch", self.iterations_in_epoch)

        if 'cuda' in self.flags.device:
            self.device = torch.device('cuda:' + str(self.gpu_nr))
        else:
            self.device = torch.device('cpu')

        self.weight = []
        self.buf = []
        for w in weights:
            t = torch.tensor(w, device=self.device, requires_grad=True)
            t.grad = torch.tensor(w, device=self.device)
            self.weight.append(t)

        self.optimizer = YFOptimizer(self.weight, lr=1, mu=0.9)
        self.optimizer.zero_grad()

    def learning_rate_func(self, epoch, update):
        if self.next_switch is not None and epoch >= self.next_switch[0]:
            logger.state("change of learning rate", epoch, self.learning_rate,
                         self.next_switch[1])
            self.learning_rate = self.next_switch[1]
            logger.state("lr", self.learning_rate)
            try:
                self.next_switch = next(self.lr_dict)
            except StopIteration:
                self.next_switch = None

        if self.epochs >= epoch:
            # print("warm up optimizer", flush=True)
            return self.learning_rate
        else:
            return self.learning_rate / self.scale

    def __call__(self, weights, update, gradients, staleness, epoch):
        """
        :param weights: Copy of the model weights
        :param update: Current update
        :param gradients: List of gradients
        :param staleness: Staleness of each gradient
        :param epoch: Current epoch
        """
        lrr = self.learning_rate_func(epoch, update)
        lr = torch.tensor(-lrr * len(gradients), dtype=torch.float,
                          device=self.device)
        self.optimizer.set_lr_factor(lrr)
        start_time = time.time()

        if not self.staleness_aware:
            gradient = gradients
        else:
            gradient = [np.divide(g, s) for g, s in zip(gradients, staleness)]
        grad = np.mean(gradient, axis=0)

        i = 0
        for elem in self.optimizer._optimizer.param_groups:
            for p in elem['params']:
                p.grad.data.copy_(torch.from_numpy(grad[i]), non_blocking=True)
                i += 1
        self.optimizer.step()

        c = 0
        wei = []
        with self.lock:
            for elem in self.optimizer._optimizer.param_groups:
                for p in elem['params']:
                    wei.append(p.data.numpy())
                    c += 1

        end_time = time.time()
        if self.flags.time_program:
            self.times.append(end_time - start_time)
        return wei

    def __del__(self):
        if self.flags.time_program:
            if self.times != []:
                t = np.mean(self.times, axis=0)
                logger.state("Optimizer took", t, flush=True)

    def close(self):
        pass
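# The heart of YellowFin.__call__ above is the staleness-aware averaging step:
# each agent's gradient is optionally divided by its staleness before the
# element-wise mean is copied into the wrapped optimizer. Below is a minimal
# NumPy-only sketch of that step; average_gradients is a hypothetical helper
# written for illustration and is not part of the project.
import numpy as np

def average_gradients(gradients, staleness, staleness_aware=True):
    """Average per-agent gradients, optionally down-weighting stale ones."""
    if staleness_aware:
        # divide each agent's gradient by its staleness before averaging
        gradients = [np.divide(g, s) for g, s in zip(gradients, staleness)]
    return np.mean(gradients, axis=0)

# two agents, one flat parameter tensor each; the second gradient is 3 steps stale:
# average_gradients([np.array([0.2, -0.4]), np.array([0.6, 0.0])], [1, 3])
# -> array([ 0.2, -0.2])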
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch version : {}".format(torch.__version__), log)
    print_log("cudnn version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # here is actually the validation dataset
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path, train=True,
                                transform=train_transform, download=True)
        test_data = dset.MNIST(args.data_path, train=False,
                               transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True,
                                  transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False,
                                 transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path, train=True,
                                   transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False,
                                  transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path, split='train',
                               transform=train_transform, download=True)
        test_data = dset.SVHN(args.data_path, split='test',
                              transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path, split='train',
                                transform=train_transform, download=True)
        test_data = dset.STL10(args.data_path, split='test',
                               transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
        else:
            net = torch.nn.DataParallel(net, device_ids=[0])

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # params without threshold
    all_param = [
        param for name, param in net.named_parameters()
        if 'delta_th' not in name
    ]
    th_param = [
        param for name, param in net.named_parameters() if 'delta_th' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param, lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)
        optimizer_th = torch.optim.SGD(th_param, lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=state['decay'],
                                       nesterov=True)
    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(all_param, lr=state['learning_rate'],
                                     weight_decay=state['decay'])
        optimizer_th = torch.optim.SGD(th_param, lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=0, nesterov=True)
    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(
            filter(lambda param: param.requires_grad, net.parameters()),
            lr=state['learning_rate'], mu=state['momentum'],
            weight_decay=state['decay'])
        # optimizer = YFOptimizer(filter(lambda param: param.requires_grad, net.parameters()))
    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(
            filter(lambda param: param.requires_grad, net.parameters()),
            lr=state['learning_rate'], alpha=0.99, eps=1e-08, weight_decay=0,
            momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not args.fine_tune:
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch),
            log)

    # Right after the pretrained model is loaded:
    '''
    When the model is loaded from a pre-trained checkpoint, the originally
    initialized thresholds are no longer correct and might be clipped by the
    hard-tanh function.
    '''
    for name, module in net.named_modules():
        name = name.replace('.', '/')
        class_name = str(module.__class__).split('.')[-1].split("'")[0]
        if "quanConv2d" in class_name or "quanLinear" in class_name:
            module.delta_th.data = module.weight.abs().max() * module.init_factor.cuda()

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # set the gradient register hook to modify the gradient (gradient clipping)
    for name, param in net.named_parameters():
        if "delta_th" in name:
            # if "delta_th" in name and 'classifier' in name:
            # based on previous experiments, the clamp interval should stay around 0.001
            param.register_hook(lambda grad: grad.clamp(min=-0.001, max=0.001))

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer_th, epoch, args.gammas, args.schedule)

        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(
                time_string(), epoch, args.epochs, need_time,
                current_learning_rate, current_momentum)
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False)), log)

        # ============ TensorBoard logging ============#
        # we show the model param initialization to give an intuition when we do the fine-tuning
        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            if "delta_th" not in name:
                writer.add_histogram(name, param.cpu().detach().numpy(), epoch)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                sparsity = Sparsity_check(module)
                writer.add_scalar(name + '/sparsity/', sparsity, epoch)
                # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)
        # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     optimizer_th, epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)

        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        is_best = val_acc >= recorder.max_accuracy(False)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save additional accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#
        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            writer.add_histogram(name + '/grad',
                                 param.grad.cpu().detach().numpy(), epoch + 1)

        # for name, module in net.named_modules():
        #     name = name.replace('.', '/')
        #     class_name = str(module.__class__).split('.')[-1].split("'")[0]
        #     if "quanConv2d" in class_name or "quanLinear" in class_name:
        #         sparsity = Sparsity_check(module)
        #         writer.add_scalar(name + '/sparsity/', sparsity, epoch + 1)
        #         # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)
class_name or "quanLinear" in class_name: # sparsity = Sparsity_check(module) # writer.add_scalar(name + '/sparsity/', sparsity, epoch + 1) # # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1) for name, module in net.named_modules(): name = name.replace('.', '/') class_name = str(module.__class__).split('.')[-1].split("'")[0] if "quanConv2d" in class_name or "quanLinear" in class_name: if module.delta_th.data is not None: if module.delta_th.dim( ) == 0: # zero-dimension tensor (scalar) not iterable writer.add_scalar(name + '/delta/', module.delta_th.detach(), epoch + 1) else: for idx, delta in enumerate(module.delta_th.detach()): writer.add_scalar( name + '/delta/' + '{}'.format(idx), delta, epoch + 1) writer.add_scalar('loss/train_loss', train_los, epoch + 1) writer.add_scalar('loss/test_loss', val_los, epoch + 1) writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1) writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1) # ============ TensorBoard logging ============# log.close()
def init_model(FLAGS, logger, initial_embeddings, vocab_size, num_classes,
               data_manager, logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
        context_args.input_dim = FLAGS.model_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x):
            return x
    else:
        raise NotImplementedError
    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None
    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim / 2
        composition = ReduceTreeLSTM(
            FLAGS.model_dim / 2,
            tracker_size=FLAGS.tracking_lstm_hidden_dim,
            use_tracking_in_composition=FLAGS.use_tracking_in_composition,
            composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate,
                                  eps=1e-08)
    elif FLAGS.optimizer_type == "YellowFin":
        optimizer = YFOptimizer(model.parameters(), lr=FLAGS.learning_rate)
        if FLAGS.actively_decay_learning_rate:
            logger.Log(
                "WARNING: Ignoring actively_decay_learning_rate and "
                "learning_rate_decay_per_10k_steps. Not implemented for YellowFin.")
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
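# The total_params computation above multiplies each parameter's shape out by
# hand with reduce(); Tensor.numel() gives the same count directly. A tiny
# sketch of the equivalence (the nn.Linear is just a stand-in model used for
# illustration).
from functools import reduce
import torch.nn as nn

def count_params(model):
    by_reduce = sum(reduce(lambda x, y: x * y, w.size(), 1.0)
                    for w in model.parameters())
    by_numel = sum(p.numel() for p in model.parameters())
    assert int(by_reduce) == by_numel
    return by_numel

# count_params(nn.Linear(10, 3))  # 10*3 weights + 3 biases -> 33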
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if not os.path.isdir(os.path.join(args.save_path, 'saved_tensors')):
        os.makedirs(os.path.join(args.save_path, 'saved_tensors'))
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch version : {}".format(torch.__version__), log)
    print_log("cudnn version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    # writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # here is actually the validation dataset
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path, train=True, download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307, ), (0.3081, ))
                                ]))
        test_data = dset.MNIST(args.data_path, train=False,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307, ), (0.3081, ))
                               ]))
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True,
                                  transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False,
                                 transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path, train=True,
                                   transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False,
                                  transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path, split='train',
                               transform=train_transform, download=True)
        test_data = dset.SVHN(args.data_path, split='test',
                              transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path, split='train',
                                transform=train_transform, download=True)
        test_data = dset.STL10(args.data_path, split='test',
                               transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    # print(len(signature(models.__dict__[args.arch]).parameters))
    model_param_dict = signature(models.__dict__[args.arch]).parameters
    # print(signature(models.__dict__[args.arch]).parameters)
    if ('AD_sigma' in model_param_dict) and ('input_grain_size' in model_param_dict):
        net = models.__dict__[args.arch](
            num_classes, args.AD_sigma, args.DA_sigma, args.input_grain_size,
            args.input_num_bits, args.input_M2D, args.res_grain_size,
            args.res_num_bits, args.res_M2D, args.output_grain_size,
            args.output_num_bits, args.output_M2D, args.save_path)
    elif 'input_grain_size' in model_param_dict:
        net = models.__dict__[args.arch](
            num_classes, args.input_grain_size, args.input_num_bits,
            args.input_M2D, args.res_grain_size, args.res_num_bits,
            args.res_M2D, args.output_grain_size, args.output_num_bits,
            args.output_M2D, args.save_path)
    elif 'AD_sigma' in model_param_dict:
        net = models.__dict__[args.arch](num_classes, args.AD_sigma,
                                         args.DA_sigma)
    else:
        net = models.__dict__[args.arch](num_classes)

    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # separate the parameters so that the param groups can be updated by different optimizers
    all_param = [
        param for name, param in net.named_parameters()
        if 'step_size' not in name
    ]
    step_param = [
        param for name, param in net.named_parameters() if 'step_size' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param, lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)
    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(all_param, lr=state['learning_rate'],
                                     weight_decay=state['decay'])
    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(
            filter(lambda param: param.requires_grad, net.parameters()),
            lr=state['learning_rate'], mu=state['momentum'],
            weight_decay=state['decay'])
    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(
            filter(lambda param: param.requires_grad, net.parameters()),
            lr=state['learning_rate'], alpha=0.99, eps=1e-08, weight_decay=0,
            momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not args.fine_tune:
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)
            # net.load_state_dict(checkpoint['state_dict'])

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch),
            log)

    # update the step_size once the model is loaded. This is used for quantization.
    for m in net.modules():
        if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
            # simple step size update based on the pretrained model or weight init
            m.__reset_stepsize__()

    # block for quantizer optimization
    if args.optimize_step:
        optimizer_quan = torch.optim.SGD(step_param, lr=0.01, momentum=0.9,
                                         weight_decay=0, nesterov=True)

        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                for i in range(300):  # runs 300 iterations to reduce quantization error
                    optimizer_quan.zero_grad()
                    weight_quan = quantize(m.weight, m.step_size,
                                           m.half_lvls) * m.step_size
                    loss_quan = F.mse_loss(weight_quan, m.weight,
                                           reduction='mean')
                    loss_quan.backward()
                    optimizer_quan.step()

        for m in net.modules():
            if isinstance(m, quan_Conv2d):
                print(m.step_size.data.item(),
                      (m.step_size.detach() * m.half_lvls).item(),
                      m.weight.max().item())

    # block for weight reset
    if args.reset_weight:
        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                m.__reset_weight__()
                # print(m.weight)

    # attacker = BFA(criterion, args.k_top)
    # net_clean = copy.deepcopy(net)
    #
    # if args.enable_bfa:
    #     perform_attack(attacker, net, net_clean, train_loader, test_loader,
    #                    args.n_iter, log, writer)
    #     return

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)

        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(
                time_string(), epoch, args.epochs, need_time,
                current_learning_rate, current_momentum)
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False)), log)
        # # ============ TensorBoard logging ============#
        # # we show the model param initialization to give an intuition when we do the fine-tuning
        # for name, param in net.named_parameters():
        #     name = name.replace('.', '/')
        #     if "delta_th" not in name:
        #         writer.add_histogram(name, param.clone().cpu().detach().numpy(), epoch)
        # # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)

        is_best = val_acc > recorder.max_accuracy(istrain=False)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save additional accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

    log.close()
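# Standalone sketch of the step-size calibration idea used in the
# optimize_step block above: fit a scalar step size by gradient descent on the
# MSE between dequantized and full-precision weights. dequantize() below is a
# hypothetical uniform quantizer written for illustration; it is not the
# project's quantize().
import torch
import torch.nn.functional as F

def dequantize(w, step, half_lvls):
    # round to integer levels using the current step (treated as a constant),
    # then rescale by the learnable step so the loss is differentiable in step
    levels = torch.clamp(torch.round(w / step), -half_lvls, half_lvls).detach()
    return levels * step

def calibrate_step(weight, half_lvls=127, iters=300):
    step = (weight.abs().max() / half_lvls).detach().clone().requires_grad_()
    opt = torch.optim.SGD([step], lr=0.01, momentum=0.9)
    for _ in range(iters):
        opt.zero_grad()
        loss = F.mse_loss(dequantize(weight, step, half_lvls), weight)
        loss.backward()
        opt.step()
    return step.detach()

# calibrate_step(torch.randn(64, 32))  # returns a fitted scalar step size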