Example no. 1
0
    def __init__(self, flags, agents, gpu_nr, memory_fraction, weights, lock,
                 scale, staleness_aware, epoch):
        super().__init__(flags, agents, gpu_nr, memory_fraction, weights, lock)
        from tuner_utils.yellowfin import YFOptimizer
        self.times = []
        self.flags = flags
        self.agents = agents
        self.gpu_nr = gpu_nr
        self.lock = lock
        self.scale = scale
        self.staleness_aware = staleness_aware
        self.nesterov = True
        self.lr_dict = iter([(1, 0.4), (82, 0.04), (123, 0.004),
                             (165, 0.0001)])
        self.next_switch = next(self.lr_dict)
        self.learning_rate = 0
        self.batchsize = self.flags.batch_size
        self.epochs = epoch

        if self.flags.drop_remainder:
            self.iterations_in_epoch = self.agents * math.floor(
                self.flags.train_set_size / (self.flags.batch_size))
        else:
            self.iterations_in_epoch = self.agents * math.ceil(
                self.flags.train_set_size / (self.flags.batch_size))

        logger.state("iterations in epoch", self.iterations_in_epoch)
        if 'cuda' in self.flags.device:
            self.device = torch.device('cuda:' + str(self.gpu_nr))
        else:
            self.device = torch.device('cpu')

        self.weight = []
        self.buf = []
        for w in weights:
            t = torch.tensor(w, device=self.device, requires_grad=True)
            t.grad = torch.tensor(w, device=self.device)
            self.weight.append(t)

        self.optimizer = YFOptimizer(self.weight, lr=1, mu=0.9)

        self.optimizer.zero_grad()
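
A minimal standalone sketch of the weight-wrapping pattern in the __init__ above (names such as wrap_weights are illustrative, not from the project): numpy arrays become trainable tensors on the target device, with .grad buffers pre-allocated so averaged gradients can later be copied into them.

import numpy as np
import torch


def wrap_weights(weights, device):
    # Turn a list of numpy arrays into trainable tensors with .grad pre-allocated.
    wrapped = []
    for w in weights:
        t = torch.tensor(w, dtype=torch.float32, device=device, requires_grad=True)
        t.grad = torch.zeros_like(t)  # placeholder buffer; gradients are copied in later
        wrapped.append(t)
    return wrapped


params = wrap_weights([np.ones((3, 3)), np.zeros(3)], torch.device('cpu'))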
Example no. 2
0
class YellowFin(UpdadteFunction):
    """
    This class is a wrapper for the YellowFin implementation found at
    https://github.com/AnonRepository/YellowFin_Pytorch and described in this
    paper: https://arxiv.org/pdf/1706.03471.pdf
    Because the code is not distributed under any licence, we cannot include it in this project.
    To make this project work, download the tuner_utils folder from the GitHub page and add it to the project folder.
    The code also has to be patched to work with newer versions of PyTorch (add .numpy to lines 275 and 283).
    """
    def __init__(self, flags, agents, gpu_nr, memory_fraction, weights, lock,
                 scale, staleness_aware, epoch):
        super().__init__(flags, agents, gpu_nr, memory_fraction, weights, lock)
        from tuner_utils.yellowfin import YFOptimizer
        self.times = []
        self.flags = flags
        self.agents = agents
        self.gpu_nr = gpu_nr
        self.lock = lock
        self.scale = scale
        self.staleness_aware = staleness_aware
        self.nesterov = True
        self.lr_dict = iter([(1, 0.4), (82, 0.04), (123, 0.004),
                             (165, 0.0001)])
        self.next_switch = next(self.lr_dict)
        self.learning_rate = 0
        self.batchsize = self.flags.batch_size
        self.epochs = epoch

        if self.flags.drop_remainder:
            self.iterations_in_epoch = self.agents * math.floor(
                self.flags.train_set_size / (self.flags.batch_size))
        else:
            self.iterations_in_epoch = self.agents * math.ceil(
                self.flags.train_set_size / (self.flags.batch_size))

        logger.state("iterations in epoch", self.iterations_in_epoch)
        if 'cuda' in self.flags.device:
            self.device = torch.device('cuda:' + str(self.gpu_nr))
        else:
            self.device = torch.device('cpu')

        self.weight = []
        self.buf = []
        for w in weights:
            t = torch.tensor(w, device=self.device, requires_grad=True)
            t.grad = torch.tensor(w, device=self.device)
            self.weight.append(t)

        self.optimizer = YFOptimizer(self.weight, lr=1, mu=0.9)

        self.optimizer.zero_grad()

    def learning_rate_func(self, epoch, update):
        if self.next_switch is not None and epoch >= self.next_switch[0]:
            logger.state("change of learning rate", epoch, self.learning_rate,
                         self.next_switch[1])
            self.learning_rate = self.next_switch[1]

            logger.state("lr", self.learning_rate)
            try:
                self.next_switch = next(self.lr_dict)
            except StopIteration:
                self.next_switch = None
        if self.epochs >= epoch:
            # print("warm up optimizer", flush=True)
            return self.learning_rate
        else:
            return self.learning_rate / self.scale

    def __call__(self, weights, update, gradients, staleness, epoch):
        """
        :param weights: Copy of the model weights
        :param update: Current update
        :param gradients: List of gradients
        :param staleness: Staleness of each gradient
        :param epoch: Current epoch
        """

        lrr = self.learning_rate_func(epoch, update)

        lr = torch.tensor(-lrr * len(gradients),
                          dtype=torch.float,
                          device=self.device)
        self.optimizer.set_lr_factor(lrr)

        start_time = time.time()
        if not self.staleness_aware:
            gradient = gradients
        else:
            gradient = [np.divide(g, s) for g, s in zip(gradients, staleness)]
        grad = np.mean(gradient, axis=0)

        i = 0
        for elem in self.optimizer._optimizer.param_groups:
            for p in elem['params']:
                p.grad.data.copy_(torch.from_numpy(grad[i]), non_blocking=True)
                i += 1

        self.optimizer.step()

        c = 0
        wei = []
        with self.lock:
            for elem in self.optimizer._optimizer.param_groups:
                for p in elem['params']:

                    wei.append(p.data.cpu().numpy())
                    c += 1
        end_time = time.time()
        if self.flags.time_program:
            self.times.append(end_time - start_time)
        return wei

    def __del__(self):
        if self.flags.time_program:
            if self.times:
                t = np.mean(self.times, axis=0)
                logger.state("Optimizer took", t, flush=True)

    def close(self):
        pass
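
For context, a minimal usage sketch of YFOptimizer on its own, assuming the tuner_utils package from the repository mentioned in the docstring is on the import path; the model and data here are placeholders, and only calls that appear in the class above (zero_grad, set_lr_factor, step) are used.

import torch
import torch.nn as nn
import torch.nn.functional as F
from tuner_utils.yellowfin import YFOptimizer  # requires the downloaded tuner_utils folder

model = nn.Linear(10, 2)
opt = YFOptimizer(model.parameters(), lr=1.0, mu=0.9)  # same constructor arguments as above

x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = F.cross_entropy(model(x), y)

opt.zero_grad()
loss.backward()
opt.set_lr_factor(0.4)  # external schedule, as in learning_rate_func above
opt.step()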
Example no. 3
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknown dataset: {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # this is actually the validation set for ImageNet
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path,
                                train=True,
                                transform=train_transform,
                                download=True)
        test_data = dset.MNIST(args.data_path,
                               train=False,
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
        else:
            net = torch.nn.DataParallel(net, device_ids=[0])

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # params without threshold
    all_param = [
        param for name, param in net.named_parameters()
        if not 'delta_th' in name
    ]

    th_param = [
        param for name, param in net.named_parameters() if 'delta_th' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param,
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)
        optimizer_th = torch.optim.SGD(th_param,
                                       lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=state['decay'],
                                       nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(all_param,
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

        optimizer_th = torch.optim.SGD(th_param,
                                       lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=0,
                                       nesterov=True)

    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(filter(lambda param: param.requires_grad,
                                       net.parameters()),
                                lr=state['learning_rate'],
                                mu=state['momentum'],
                                weight_decay=state['decay'])
    # optimizer = YFOptimizer( filter(lambda param: param.requires_grad, net.parameters()) )
    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(
            lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'],
                                        alpha=0.99,
                                        eps=1e-08,
                                        weight_decay=0,
                                        momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    # Right after the pretrained model is loaded:
    '''
    When the model is loaded from the pre-trained checkpoint, the originally
    initialized thresholds are no longer correct and might be clipped by the
    hard-tanh function.
    '''
    for name, module in net.named_modules():
        name = name.replace('.', '/')
        class_name = str(module.__class__).split('.')[-1].split("'")[0]
        if "quanConv2d" in class_name or "quanLinear" in class_name:
            module.delta_th.data = module.weight.abs().max() * module.init_factor.cuda()

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # register a gradient hook to modify the gradient (gradient clipping)
    for name, param in net.named_parameters():
        if "delta_th" in name:
            # if "delta_th" in name and 'classifier' in name:
            # based on previous experiments, the clamp interval is best kept within [-0.001, 0.001]
            param.register_hook(lambda grad: grad.clamp(min=-0.001, max=0.001))

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer_th, epoch, args.gammas, args.schedule)

        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate,
                                                                                   current_momentum) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # ============ TensorBoard logging ============#
        # we show the model param initialization to give an intuition when we do the fine-tuning

        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            if "delta_th" not in name:
                writer.add_histogram(name, param.cpu().detach().numpy(), epoch)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                sparsity = Sparsity_check(module)
                writer.add_scalar(name + '/sparsity/', sparsity, epoch)
                # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)

        # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     optimizer_th, epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        is_best = val_acc >= recorder.max_accuracy(False)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save addition accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#

        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            writer.add_histogram(name + '/grad',
                                 param.grad.cpu().detach().numpy(), epoch + 1)

        # for name, module in net.named_modules():
        #     name = name.replace('.', '/')
        #     class_name = str(module.__class__).split('.')[-1].split("'")[0]
        #     if "quanConv2d" in class_name or "quanLinear" in class_name:
        #         sparsity = Sparsity_check(module)
        #         writer.add_scalar(name + '/sparsity/', sparsity, epoch + 1)
        #         # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                if module.delta_th.data is not None:
                    if module.delta_th.dim() == 0:  # zero-dim tensor (scalar) is not iterable
                        writer.add_scalar(name + '/delta/',
                                          module.delta_th.detach(), epoch + 1)
                    else:
                        for idx, delta in enumerate(module.delta_th.detach()):
                            writer.add_scalar(
                                name + '/delta/' + '{}'.format(idx), delta,
                                epoch + 1)

        writer.add_scalar('loss/train_loss', train_los, epoch + 1)
        writer.add_scalar('loss/test_loss', val_los, epoch + 1)
        writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1)
        writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1)
    # ============ TensorBoard logging ============#

    log.close()
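
The gradient-clipping hook registered on the delta_th parameters above can be illustrated in isolation; this is a small sketch on a throwaway nn.Linear, with 'weight' standing in for the delta_th parameter name.

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
for name, param in net.named_parameters():
    if 'weight' in name:  # stand-in for the "delta_th" parameters used above
        param.register_hook(lambda grad: grad.clamp(min=-0.001, max=0.001))

net(torch.randn(2, 4)).sum().backward()
print(net.weight.grad.abs().max())  # no entry exceeds 0.001 after the hook fires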
Example no. 4
0
def init_model(
        FLAGS,
        logger,
        initial_embeddings,
        vocab_size,
        num_classes,
        data_manager,
        logfile_header=None):
    # Choose model.
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        build_model = spinn.cbow.build_model
    elif FLAGS.model_type == "RNN":
        build_model = spinn.plain_rnn.build_model
    elif FLAGS.model_type == "SPINN":
        build_model = spinn.spinn_core_model.build_model
    elif FLAGS.model_type == "RLSPINN":
        build_model = spinn.rl_spinn.build_model
    elif FLAGS.model_type == "Pyramid":
        build_model = spinn.pyramid.build_model
    elif FLAGS.model_type == "ChoiPyramid":
        build_model = spinn.choi_pyramid.build_model
    else:
        raise NotImplementedError

    # Input Encoder.
    context_args = Args()
    context_args.reshape_input = lambda x, batch_size, seq_length: x
    context_args.reshape_context = lambda x, batch_size, seq_length: x
    context_args.input_dim = FLAGS.word_embedding_dim

    if FLAGS.encode == "projection":
        encoder = Linear()(FLAGS.word_embedding_dim, FLAGS.model_dim)
        context_args.input_dim = FLAGS.model_dim
    elif FLAGS.encode == "gru":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = EncodeGRU(FLAGS.word_embedding_dim, FLAGS.model_dim,
                            num_layers=FLAGS.encode_num_layers,
                            bidirectional=FLAGS.encode_bidirectional,
                            reverse=FLAGS.encode_reverse)
    elif FLAGS.encode == "attn":
        context_args.reshape_input = lambda x, batch_size, seq_length: x.view(
            batch_size, seq_length, -1)
        context_args.reshape_context = lambda x, batch_size, seq_length: x.view(
            batch_size * seq_length, -1)
        context_args.input_dim = FLAGS.model_dim
        encoder = IntraAttention(FLAGS.word_embedding_dim, FLAGS.model_dim)
    elif FLAGS.encode == "pass":
        def encoder(x): return x
    else:
        raise NotImplementedError

    context_args.encoder = encoder

    # Composition Function.
    composition_args = Args()
    composition_args.lateral_tracking = FLAGS.lateral_tracking
    composition_args.tracking_ln = FLAGS.tracking_ln
    composition_args.use_tracking_in_composition = FLAGS.use_tracking_in_composition
    composition_args.size = FLAGS.model_dim
    composition_args.tracker_size = FLAGS.tracking_lstm_hidden_dim
    composition_args.use_internal_parser = FLAGS.use_internal_parser
    composition_args.transition_weight = FLAGS.transition_weight
    composition_args.wrap_items = lambda x: torch.cat(x, 0)
    composition_args.extract_h = lambda x: x
    composition_args.extract_c = None

    composition_args.detach = FLAGS.transition_detach
    composition_args.evolution = FLAGS.evolution

    if FLAGS.reduce == "treelstm":
        assert FLAGS.model_dim % 2 == 0, 'model_dim must be an even number.'
        if FLAGS.model_dim != FLAGS.word_embedding_dim:
            print('If you are setting different hidden layer and word '
                  'embedding sizes, make sure you specify an encoder')
        composition_args.wrap_items = lambda x: bundle(x)
        composition_args.extract_h = lambda x: x.h
        composition_args.extract_c = lambda x: x.c
        composition_args.size = FLAGS.model_dim // 2
        composition = ReduceTreeLSTM(FLAGS.model_dim // 2,
                                     tracker_size=FLAGS.tracking_lstm_hidden_dim,
                                     use_tracking_in_composition=FLAGS.use_tracking_in_composition,
                                     composition_ln=FLAGS.composition_ln)
    elif FLAGS.reduce == "tanh":
        class ReduceTanh(nn.Module):
            def forward(self, lefts, rights, tracking=None):
                batch_size = len(lefts)
                ret = torch.cat(lefts, 0) + F.tanh(torch.cat(rights, 0))
                return torch.chunk(ret, batch_size, 0)
        composition = ReduceTanh()
    elif FLAGS.reduce == "treegru":
        composition = ReduceTreeGRU(FLAGS.model_dim,
                                    FLAGS.tracking_lstm_hidden_dim,
                                    FLAGS.use_tracking_in_composition)
    else:
        raise NotImplementedError

    composition_args.composition = composition

    model = build_model(data_manager, initial_embeddings, vocab_size,
                        num_classes, FLAGS, context_args, composition_args)

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    elif FLAGS.optimizer_type == "YellowFin":
        optimizer = YFOptimizer(model.parameters(), lr=FLAGS.learning_rate)
        if FLAGS.actively_decay_learning_rate:
            logger.Log(
                "WARNING: Ignoring actively_decay_learning_rate and learning_rate_decay_per_10k_steps. Not implemented for YellowFin.")
    else:
        raise NotImplementedError

    # Build trainer.
    if FLAGS.evolution:
        trainer = ModelTrainer_ES(model, optimizer)
    else:
        trainer = ModelTrainer(model, optimizer)

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    if logfile_header:
        logfile_header.model_architecture = str(model)
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))
    if logfile_header:
        logfile_header.total_params = int(total_params)

    return model, optimizer, trainer
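
As an aside, the reduce-over-w.size() product used for the parameter count above is equivalent to summing numel() (and avoids needing functools.reduce in Python 3); a short sketch with a placeholder model:

import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model
total_params = sum(p.numel() for p in model.parameters())
print("Total params: {}".format(total_params))  # 10 * 2 + 2 = 22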
Example no. 5
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if not os.path.isdir(os.path.join(args.save_path, 'saved_tensors')):
        os.makedirs(os.path.join(args.save_path, 'saved_tensors'))
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    # writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknown dataset: {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # this is actually the validation set for ImageNet
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path,
                                train=True,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307, ),
                                                         (0.3081, ))
                                ]))
        test_data = dset.MNIST(args.data_path,
                               train=False,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307, ), (0.3081, ))
                               ]))
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    # print(len(signature(models.__dict__[args.arch]).parameters))
    model_param_dict = signature(models.__dict__[args.arch]).parameters
    # print(signature(models.__dict__[args.arch]).parameters)
    if ('AD_sigma' in model_param_dict) and ('input_grain_size'
                                             in model_param_dict):
        net = models.__dict__[args.arch](
            num_classes, args.AD_sigma, args.DA_sigma, args.input_grain_size,
            args.input_num_bits, args.input_M2D, args.res_grain_size,
            args.res_num_bits, args.res_M2D, args.output_grain_size,
            args.output_num_bits, args.output_M2D, args.save_path)
    elif 'input_grain_size' in model_param_dict:
        net = models.__dict__[args.arch](
            num_classes, args.input_grain_size, args.input_num_bits,
            args.input_M2D, args.res_grain_size, args.res_num_bits,
            args.res_M2D, args.output_grain_size, args.output_num_bits,
            args.output_M2D, args.save_path)
    elif 'AD_sigma' in model_param_dict:
        net = models.__dict__[args.arch](num_classes, args.AD_sigma,
                                         args.DA_sigma)
    else:
        net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # separate the parameters so the param groups can be updated by different optimizers
    all_param = [
        param for name, param in net.named_parameters()
        if not 'step_size' in name
    ]

    step_param = [
        param for name, param in net.named_parameters() if 'step_size' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param,
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(all_param,
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(filter(lambda param: param.requires_grad,
                                       net.parameters()),
                                lr=state['learning_rate'],
                                mu=state['momentum'],
                                weight_decay=state['decay'])

    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(
            lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'],
                                        alpha=0.99,
                                        eps=1e-08,
                                        weight_decay=0,
                                        momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)
            # net.load_state_dict(checkpoint['state_dict'])

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)
    # update the step_size once the model is loaded. This is used for quantization.
    for m in net.modules():
        if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
            # simple step size update based on the pretrained model or weight init
            m.__reset_stepsize__()
    # block for quantizer optimization
    if args.optimize_step:
        optimizer_quan = torch.optim.SGD(step_param,
                                         lr=0.01,
                                         momentum=0.9,
                                         weight_decay=0,
                                         nesterov=True)

        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                for i in range(300):  # run 300 iterations to reduce quantization error
                    optimizer_quan.zero_grad()
                    weight_quan = quantize(m.weight, m.step_size,
                                           m.half_lvls) * m.step_size
                    loss_quan = F.mse_loss(weight_quan,
                                           m.weight,
                                           reduction='mean')
                    loss_quan.backward()
                    optimizer_quan.step()

        for m in net.modules():
            if isinstance(m, quan_Conv2d):
                print(m.step_size.data.item(),
                      (m.step_size.detach() * m.half_lvls).item(),
                      m.weight.max().item())
    # block for weight reset
    if args.reset_weight:
        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                m.__reset_weight__()
                # print(m.weight)

    # attacker = BFA(criterion, args.k_top)
    # net_clean = copy.deepcopy(net)
    #
    # if args.enable_bfa:
    #     perform_attack(attacker, net, net_clean, train_loader, test_loader,
    #                    args.n_iter, log, writer)
    #     return

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.
            format(time_string(), epoch, args.epochs, need_time,
                   current_learning_rate, current_momentum) +
            ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False)), log)

        # # ============ TensorBoard logging ============#
        # # we show the model param initialization to give a intuition when we do the fine tuning

        # for name, param in net.named_parameters():
        #     name = name.replace('.', '/')
        #     if "delta_th" not in name:
        #         writer.add_histogram(name, param.clone().cpu().detach().numpy(), epoch)

        # # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)

        is_best = val_acc > recorder.max_accuracy(istrain=False)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save addition accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

    log.close()
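
The step-size calibration block above (the optimizer_quan loop) can be reduced to a standalone sketch: a scalar step size is tuned with SGD to minimize the MSE between a fixed weight tensor and its uniform quantization. The quantizer below is a toy stand-in for the project's quantize() helper, with the rounded levels detached so the gradient reaches step_size only through the rescaling; half_lvls and the model-free weight tensor are illustrative values.

import torch
import torch.nn.functional as F

weight = torch.randn(64, 64)
half_lvls = 127  # e.g. signed 8-bit levels; illustrative value
step_size = (weight.abs().max() / half_lvls).clone().detach().requires_grad_(True)
optimizer_quan = torch.optim.SGD([step_size], lr=0.01, momentum=0.9, nesterov=True)

for _ in range(300):  # same iteration budget as above
    optimizer_quan.zero_grad()
    levels = torch.round(weight / step_size).clamp(-half_lvls, half_lvls).detach()
    weight_quan = levels * step_size
    loss_quan = F.mse_loss(weight_quan, weight, reduction='mean')
    loss_quan.backward()
    optimizer_quan.step()

print(step_size.item())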