Example #1
    def __init__(self):
        super(DartsWrapper, self).__init__()

        args = AttrDict(self.args.__dict__)
        self.args = args
        self.seed = args.seed

        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = False
        cudnn.enabled = True
        cudnn.deterministic = True
        torch.cuda.manual_seed_all(args.seed)

        self.train_queue, self.valid_queue, _, _ = super(
            DartsWrapper, self).get_train_val_loaders()
        # worker_init_fn must be a callable; np.random.seed itself returns
        # None, hence the lambda (unused while num_workers is 0, but correct
        # if workers are ever re-enabled)
        setattr(self.train_queue, 'num_workers', 0)
        setattr(self.train_queue, 'worker_init_fn',
                lambda worker_id: np.random.seed(args.seed))
        setattr(self.valid_queue, 'num_workers', 0)
        setattr(self.valid_queue, 'worker_init_fn',
                lambda worker_id: np.random.seed(args.seed))

        self.train_iter = iter(self.train_queue)
        self.valid_iter = iter(self.valid_queue)

        self.steps = 0
        self.epochs = 0
        self.total_loss = 0
        self.start_time = time.time()
        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()
        self.criterion = criterion

        self.primitives = spaces_dict[args.space]

        model = Network(args.init_channels,
                        args.n_classes,
                        args.layers,
                        self.criterion,
                        self.primitives,
                        steps=args.nodes)

        model = model.cuda()
        self.model = model

        optimizer = torch.optim.SGD(self.model.parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.optimizer = optimizer

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
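
The constructor above pins every RNG source (numpy, random, torch, cuDNN)
before building loaders and the model. As a standalone illustration, here is a
minimal sketch of that determinism recipe; the helper name seed_everything is
ours, not part of this codebase:

import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

def seed_everything(seed, gpu=0):
    # hypothetical helper mirroring the seeding block in the constructor above
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu)
        torch.cuda.manual_seed_all(seed)
        # trade throughput for reproducibility, as the wrapper does
        cudnn.benchmark = False
        cudnn.deterministic = True
        cudnn.enabled = True
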
Example #2
def main(primitives):
    if not torch.cuda.is_available() or args.disable_cuda:
        logging.info('no gpu device available or cuda disabled')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.disable_cuda:
        torch.cuda.set_device(args.gpu)
        logging.info('gpu device = %d' % args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.cuda.manual_seed(args.seed)

    criterion = nn.CrossEntropyLoss()
    if not args.disable_cuda:
        criterion = criterion.cuda()

    model_init = Network(args.init_channels,
                         args.n_classes,
                         args.layers,
                         criterion,
                         primitives,
                         steps=args.nodes,
                         args=args)
    if not args.disable_cuda:
        model_init = model_init.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model_init))

    optimizer_init = torch.optim.SGD(model_init.parameters(),
                                     args.learning_rate,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)

    architect_init = Architect(model_init, args)

    scheduler_init = CosineAnnealingLR(optimizer_init,
                                       float(args.epochs),
                                       eta_min=args.learning_rate_min)

    analyser = Analyzer(args, model_init)
    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)

    train_queue, valid_queue, train_transform, valid_transform = \
        helper.get_train_val_loaders()

    def valid_generator():
        while True:
            for x, t in valid_queue:
                yield x, t

    valid_gen = valid_generator()
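    # note: valid_generator re-creates the DataLoader iterator on every pass,
    # so a shuffling sampler yields a fresh batch order each cycle, whereas
    # itertools.cycle(valid_queue) would replay the first pass from its cache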

    for epoch in range(args.ev_start_epoch - 1, args.epochs):
        beta_decay_scheduler.step(epoch)
        logging.info("EPOCH %d SKIP BETA DECAY RATE: %e", epoch,
                     beta_decay_scheduler.decay_rate)
        if (epoch % args.report_freq_hessian == 0) or (epoch == args.epochs - 1):
            lr = utils.load_checkpoint(model_init, optimizer_init, None,
                                       architect_init, args.save, la_tracker,
                                       epoch, args.task_id)
            logging.info("Loaded %d-th checkpoint." % epoch)

            if args.test_infer:
                valid_acc, valid_obj = infer(valid_queue, model_init,
                                             criterion)
                logging.info('valid_acc %f', valid_acc)

            if args.compute_hessian:
                input, target = next(iter(train_queue))
                input = Variable(input, requires_grad=False)
                target = Variable(target, requires_grad=False)
                input_search, target_search = next(
                    valid_gen)  #next(iter(valid_queue))
                input_search = Variable(input_search, requires_grad=False)
                target_search = Variable(target_search, requires_grad=False)

                if not args.disable_cuda:
                    input = input.cuda()
                    # non_blocking replaces the legacy `async` kwarg, which
                    # collides with a reserved word in Python >= 3.7
                    target = target.cuda(non_blocking=True)
                    input_search = input_search.cuda()
                    target_search = target_search.cuda(non_blocking=True)

                if not args.debug:
                    H = analyser.compute_Hw(input, target, input_search,
                                            target_search, lr, optimizer_init,
                                            False)
                    g = analyser.compute_dw(input, target, input_search,
                                            target_search, lr, optimizer_init,
                                            False)
                    g = torch.cat([x.view(-1) for x in g])

                    state = {
                        'epoch': epoch,
                        'H': H.cpu().data.numpy().tolist(),
                        'g': g.cpu().data.numpy().tolist(),
                        #'g_train': float(grad_norm),
                        #'eig_train': eigenvalue,
                    }

                    with codecs.open(os.path.join(
                            args.save,
                            'derivatives_{}.json'.format(args.task_id)),
                                     'a',
                                     encoding='utf-8') as file:
                        json.dump(state, file, separators=(',', ':'))
                        file.write('\n')

                    # early stopping: track the dominant Hessian eigenvalue
                    # (H is symmetric, so its spectrum is real; .real guards
                    # against numerical noise in eigvals)
                    ev = max(LA.eigvals(H.cpu().data.numpy()).real)
                    logging.info('CURRENT EV: %f', ev)
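
The quantity logged as CURRENT EV is the dominant eigenvalue of the
architecture Hessian returned by analyser.compute_Hw. A self-contained sketch
of that step on a toy symmetric matrix, with the Analyzer call replaced by a
hand-built H (numpy only):

import numpy as np
from numpy import linalg as LA

# toy stand-in for the architecture Hessian computed above
H = np.array([[2.0, 0.5],
              [0.5, 1.0]])

# dominant eigenvalue, the early-stopping signal tracked by EVLocalAvg;
# for a symmetric H, eigvalsh is the cheaper and numerically safer choice
ev = max(LA.eigvals(H).real)
ev_sym = LA.eigvalsh(H)[-1]  # eigvalsh returns eigenvalues in ascending order
assert np.isclose(ev, ev_sym)
print('CURRENT EV: %f' % ev)
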
Example #3
def main(primitives):
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model_init = Network(args.init_channels,
                         args.n_classes,
                         args.layers,
                         criterion,
                         primitives,
                         steps=args.nodes)
    model_init = model_init.cuda()
    #logging.info("param size = %fMB", utils.count_parameters_in_MB(model_init))
    optimizer_init = torch.optim.SGD(model_init.parameters(),
                                     args.learning_rate,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)
    architect_init = Architect(model_init, args)

    scheduler_init = CosineAnnealingLR(optimizer_init,
                                       float(args.epochs),
                                       eta_min=args.learning_rate_min)

    analyser_init = Analyzer(args, model_init)
    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)

    if args.resume:
        if os.path.isfile(args.resume_file):
            print("=> loading checkpoint '{}'".format(args.resume_file))
            checkpoint = torch.load(args.resume_file)
            start_epoch = 27  # NOTE: resume epoch is hardcoded, not read from the checkpoint
            print('start_epoch', start_epoch)
            model_init.load_state_dict(checkpoint['state_dict'])
            model_init.alphas_normal.data = checkpoint['alphas_normal']
            model_init.alphas_reduce.data = checkpoint['alphas_reduce']
            model_init = model_init.cuda()
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(model_init))
            optimizer_init.load_state_dict(checkpoint['optimizer'])
            architect_init.optimizer.load_state_dict(
                checkpoint['arch_optimizer'])
            scheduler_init = CosineAnnealingLR(optimizer_init,
                                               float(args.epochs),
                                               eta_min=args.learning_rate_min)
            analyser_init = Analyzer(args, model_init)
            la_tracker = utils.EVLocalAvg(args.window,
                                          args.report_freq_hessian,
                                          args.epochs)
            la_tracker.ev = checkpoint['ev']
            la_tracker.ev_local_avg = checkpoint['ev_local_avg']
            la_tracker.genotypes = checkpoint['genotypes']
            la_tracker.la_epochs = checkpoint['la_epochs']
            la_tracker.la_start_idx = checkpoint['la_start_idx']
            la_tracker.la_end_idx = checkpoint['la_end_idx']
            lr = checkpoint['lr']

    train_queue, valid_queue, train_transform, valid_transform = \
        helper.get_train_val_loaders()

    errors_dict = {
        'train_acc': [],
        'train_loss': [],
        'valid_acc': [],
        'valid_loss': []
    }

    #for epoch in range(args.epochs):
    def train_epochs(epochs_to_train,
                     iteration,
                     args=args,
                     model=model_init,
                     optimizer=optimizer_init,
                     scheduler=scheduler_init,
                     train_queue=train_queue,
                     valid_queue=valid_queue,
                     train_transform=train_transform,
                     valid_transform=valid_transform,
                     architect=architect_init,
                     criterion=criterion,
                     primitives=primitives,
                     analyser=analyser_init,
                     la_tracker=la_tracker,
                     errors_dict=errors_dict,
                     start_epoch=-1):

        logging.info('STARTING ITERATION: %d', iteration)
        logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)

        la_tracker.stop_search = False

        if epochs_to_train - start_epoch - 1 <= 0:
            return model.genotype(), -1
        for epoch in range(start_epoch + 1, epochs_to_train):
            # set the epoch to the right one
            #epoch += args.epochs - epochs_to_train

            # legacy scheduler API: step(epoch) is called at the top of the
            # epoch, after which get_lr() reports that epoch's learning rate
            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            if args.drop_path_prob != 0:
                model.drop_path_prob = args.drop_path_prob * epoch / (args.epochs - 1)
                train_transform.transforms[-1].cutout_prob = \
                    args.cutout_prob * epoch / (args.epochs - 1)
                logging.info('epoch %d lr %e drop_prob %e cutout_prob %e',
                             epoch, lr, model.drop_path_prob,
                             train_transform.transforms[-1].cutout_prob)
            else:
                logging.info('epoch %d lr %e', epoch, lr)

            # training
            train_acc, train_obj = train(epoch, primitives, train_queue,
                                         valid_queue, model, architect,
                                         criterion, optimizer, lr, analyser,
                                         la_tracker, iteration)
            logging.info('train_acc %f', train_acc)

            # validation
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

            # update the errors dictionary
            errors_dict['train_acc'].append(100 - train_acc)
            errors_dict['train_loss'].append(train_obj)
            errors_dict['valid_acc'].append(100 - valid_acc)
            errors_dict['valid_loss'].append(valid_obj)

            genotype = model.genotype()

            logging.info('genotype = %s', genotype)

            print(F.softmax(model.alphas_normal, dim=-1))
            print(F.softmax(model.alphas_reduce, dim=-1))

            state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'alphas_normal': model.alphas_normal.data,
                'alphas_reduce': model.alphas_reduce.data,
                'arch_optimizer': architect.optimizer.state_dict(),
                'lr': lr,
                'ev': la_tracker.ev,
                'ev_local_avg': la_tracker.ev_local_avg,
                'genotypes': la_tracker.genotypes,
                'la_epochs': la_tracker.la_epochs,
                'la_start_idx': la_tracker.la_start_idx,
                'la_end_idx': la_tracker.la_end_idx,
                #'scheduler': scheduler.state_dict(),
            }

            utils.save_checkpoint(state, False, args.save, epoch, args.task_id)

            if not args.compute_hessian:
                ev = -1
            else:
                ev = la_tracker.ev[-1]
            params = {
                'iteration': iteration,
                'epoch': epoch,
                'wd': args.weight_decay,
                'ev': ev,
            }

            schedule_of_params.append(params)

            # limit the number of iterations based on the maximum regularization
            # value predefined by the user
            final_iteration = round(
                np.log(args.max_weight_decay) / np.log(args.weight_decay),
                1) == 1.  # stop once weight decay has been pushed to that maximum

            if la_tracker.stop_search and not final_iteration:
                if args.early_stop == 1:
                    # set the following to the values they had at stop_epoch
                    errors_dict['valid_acc'] = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    genotype = la_tracker.stop_genotype
                    valid_acc = 100 - errors_dict['valid_acc'][
                        la_tracker.stop_epoch]
                    logging.info(
                        'Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('Validation accuracy at stop epoch: %f',
                                 valid_acc)
                    logging.info('Genotype at stop epoch: %s', genotype)
                    break

                elif args.early_stop == 2:
                    # simulate early stopping and continue search afterwards
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    logging.info(
                        '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('(SIM) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(SIM) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    with open(
                            os.path.join(args.save,
                                         'arch_early_{}'.format(args.task_id)),
                            'w') as file:
                        file.write(str(simulated_genotype))

                    utils.write_yaml_results(args,
                                             'early_' + args.results_file_arch,
                                             str(simulated_genotype))
                    utils.write_yaml_results(args, 'early_stop_epochs',
                                             la_tracker.stop_epoch)

                    args.early_stop = 0

                elif args.early_stop == 3:
                    # adjust regularization
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    stop_epoch = la_tracker.stop_epoch
                    start_again_epoch = stop_epoch - args.extra_rollback_epochs
                    logging.info(
                        '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                        stop_epoch, epoch)
                    logging.info('(ADA) Rolling back to epoch %d',
                                 start_again_epoch)
                    logging.info(
                        '(ADA) Restoring model parameters and continuing for %d epochs',
                        epochs_to_train - start_again_epoch - 1)

                    if iteration == 1:
                        logging.info(
                            '(ADA) Saving the architecture at the early stop epoch and '
                            'continuing with the adaptive regularization strategy'
                        )
                        utils.write_yaml_results(
                            args, 'early_' + args.results_file_arch,
                            str(simulated_genotype))

                    del model
                    del architect
                    del optimizer
                    del scheduler
                    del analyser

                    model_new = Network(args.init_channels,
                                        args.n_classes,
                                        args.layers,
                                        criterion,
                                        primitives,
                                        steps=args.nodes)
                    model_new = model_new.cuda()

                    optimizer_new = torch.optim.SGD(
                        model_new.parameters(),
                        args.learning_rate,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

                    architect_new = Architect(model_new, args)

                    analyser_new = Analyzer(args, model_new)

                    la_tracker = utils.EVLocalAvg(args.window,
                                                  args.report_freq_hessian,
                                                  args.epochs)

                    lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                               architect_new, args.save,
                                               la_tracker, start_again_epoch,
                                               args.task_id)

                    args.weight_decay *= args.mul_factor
                    for param_group in optimizer_new.param_groups:
                        param_group['weight_decay'] = args.weight_decay

                    scheduler_new = CosineAnnealingLR(
                        optimizer_new,
                        float(args.epochs),
                        eta_min=args.learning_rate_min)

                    logging.info('(ADA) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(ADA) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    logging.info(
                        '(ADA) Adjusting L2 regularization to the new value: %f',
                        args.weight_decay)

                    genotype, valid_acc = train_epochs(args.epochs,
                                                       iteration + 1,
                                                       model=model_new,
                                                       optimizer=optimizer_new,
                                                       architect=architect_new,
                                                       scheduler=scheduler_new,
                                                       analyser=analyser_new,
                                                       start_epoch=start_epoch)
                    args.early_stop = 0
                    break

        return genotype, valid_acc

    # first call; train_epochs recurses on adaptive-regularization restarts
    genotype, valid_acc = train_epochs(args.epochs, 1)

    with codecs.open(os.path.join(args.save,
                                  'errors_{}.json'.format(args.task_id)),
                     'w',
                     encoding='utf-8') as file:
        json.dump(errors_dict, file, separators=(',', ':'))

    with open(os.path.join(args.save, 'arch_{}'.format(args.task_id)),
              'w') as file:
        file.write(str(genotype))

    utils.write_yaml_results(args, args.results_file_arch, str(genotype))
    utils.write_yaml_results(args, args.results_file_perf, 100 - valid_acc)

    with open(
            os.path.join(args.save, 'schedule_{}.pickle'.format(args.task_id)),
            'ab') as file:
        pickle.dump(schedule_of_params, file, pickle.HIGHEST_PROTOCOL)
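
The (ADA) branch above raises L2 regularization mid-search by rewriting
weight_decay inside the optimizer's param groups before rebuilding the
scheduler. A minimal sketch of that mechanism on a throwaway linear model;
mul_factor is a stand-in constant playing the role of args.mul_factor:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.025,
                            momentum=0.9, weight_decay=3e-4)

mul_factor = 10.0  # stand-in for args.mul_factor
for param_group in optimizer.param_groups:
    # SGD reads weight_decay from its param groups on every step, so the
    # new value takes effect immediately without rebuilding the optimizer
    param_group['weight_decay'] *= mul_factor

print(optimizer.param_groups[0]['weight_decay'])  # 0.003
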
Example #4
def compute_landscape(base_dir, primitives):
    checkpoint_path = 'checkpoint_%d_%d.pth.tar' % (args.task_id,
                                                    args.checkpoint_epoch - 1)
    weight_file = os.path.join(base_dir, checkpoint_path)
    #  alpha_file = os.path.join(base_dir, 'results_of_7q/alpha/49.txt')
    #  alpha_value = load_alpha(alpha_file)
    beta_decay_scheduler.step(args.checkpoint_epoch - 1)
    ckpt = torch.load(weight_file, map_location=lambda storage, loc: storage)
    #  print(ckpt.keys())
    state_dict = ckpt['state_dict']
    alpha_value = [ckpt['alphas_normal'], ckpt['alphas_reduce']]
    del ckpt

    # construct model & load weights and alpha
    print(">>> Constructing model & load weights and alphas")
    criterion = nn.CrossEntropyLoss()
    if not args.disable_cuda:
        criterion = criterion.cuda()

    model = Network(args.init_channels,
                    args.n_classes,
                    args.layers,
                    criterion,
                    primitives,
                    steps=args.nodes,
                    args=args)
    if not args.disable_cuda:
        model = model.cuda()

    model.load_state_dict(state_dict, strict=False)
    arch_params = model.arch_parameters()
    for i, alpha in enumerate(arch_params):
        alpha.data.copy_(alpha_value[i])

    # get data_loader
    print(">>> Constructing dataloader")
    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    # both loaders draw from the CIFAR-10 train set: the first `split`
    # indices feed the search-train queue, the remainder the validation queue
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)

    if args.test_infer:
        print(">>> First check that the weights and alphas loaded correctly")
        acc, loss = infer(valid_queue, model, criterion)
        print(acc, loss)

    landscape_dir = os.path.join(base_dir, 'vis_landscape')
    if not os.path.exists(landscape_dir):
        os.makedirs(landscape_dir)
    # obtain directions of perturbation
    print(">>> Obtaining directions of perturbation")
    direction_file = os.path.join(landscape_dir,
                                  'direction-%d.json' % args.task_id)
    d1, d2 = obtain_direction(model,
                              valid_queue,
                              direction_file=direction_file,
                              method='random',
                              norm_type='cellwise')

    # compute landscape
    print(">>> Computing landscape")
    x = np.linspace(args.xmin, args.xmax, num=args.xnum)
    # the y axis reuses the x range and resolution, so the grid is square
    y = np.linspace(args.xmin, args.xmax, num=args.xnum)
    losses = np.zeros([len(x), len(y)])
    accs = np.zeros([len(x), len(y)])
    for i, delta1 in enumerate(x):
        for j, delta2 in enumerate(y):
            arch_params[0].data.copy_(alpha_value[0] + d1[0] * delta1 +
                                      d2[0] * delta2)
            arch_params[1].data.copy_(alpha_value[1] + d1[1] * delta1 +
                                      d2[1] * delta2)
            acc, loss = infer(valid_queue, model, criterion)
            losses[i, j] = loss
            accs[i, j] = acc
            print('x,y/acc/loss: %f,%f/%f/%f' % (delta1, delta2, acc, loss))

    # save loss & acc
    results_file = os.path.join(landscape_dir, 'results-%d.csv' % args.task_id)
    title = (['X', 'Y'] + ['loss_%d' % i for i in range(losses.shape[1])] +
             ['acc_%d' % i for i in range(accs.shape[1])])
    with open(results_file, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(title)
        for i in range(losses.shape[0]):
            row = [x[i], y[i]] + losses[i, :].tolist() + accs[i, :].tolist()
            writer.writerow(row)

    return x, y, losses, accs
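
Once losses has been filled over the (x, y) grid, a contour plot is the usual
way to render the landscape. A minimal sketch with matplotlib, using random
placeholder data in place of the arrays returned by compute_landscape:

import numpy as np
import matplotlib.pyplot as plt

# stand-ins for the grid produced by compute_landscape
x = np.linspace(-1.0, 1.0, num=11)
y = np.linspace(-1.0, 1.0, num=11)
losses = np.random.rand(len(x), len(y))  # placeholder loss surface

# losses[i, j] was evaluated at (x[i], y[j]), hence 'ij' mesh indexing
X, Y = np.meshgrid(x, y, indexing='ij')
cs = plt.contourf(X, Y, losses, levels=20)
plt.colorbar(cs, label='validation loss')
plt.xlabel('delta1')
plt.ylabel('delta2')
plt.savefig('landscape.png')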