示例#1
0
def main(primitives):
    """Run the architecture search and write its results under ``args.save``.

    Builds the search network, its optimizers, the LR scheduler, the analyser
    and the ``utils.EVLocalAvg`` tracker, optionally restores them from
    ``args.resume_file``, trains through the nested ``train_epochs`` helper and
    finally writes the error curves, the selected genotype and the collected
    ``schedule_of_params`` to disk.

    Relies on the module-level ``args`` and ``schedule_of_params`` objects.
    Exits the process with status 1 when no CUDA device is available.

    Args:
        primitives: Forwarded unchanged to ``Network`` and to ``train``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed every RNG in use and pin the CUDA device.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model_init = Network(args.init_channels,
                         args.n_classes,
                         args.layers,
                         criterion,
                         primitives,
                         steps=args.nodes)
    model_init = model_init.cuda()
    optimizer_init = torch.optim.SGD(model_init.parameters(),
                                     args.learning_rate,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay)
    architect_init = Architect(model_init, args)

    scheduler_init = CosineAnnealingLR(optimizer_init,
                                       float(args.epochs),
                                       eta_min=args.learning_rate_min)

    analyser_init = Analyzer(args, model_init)
    la_tracker = utils.EVLocalAvg(args.window, args.report_freq_hessian,
                                  args.epochs)

    if args.resume:
        if os.path.isfile(args.resume_file):
            print("=> loading checkpoint '{}'".format(args.resume_file))
            checkpoint = torch.load(args.resume_file)
            # NOTE(review): this epoch is hard-coded and only printed; it is
            # never passed to train_epochs below, so training restarts at
            # epoch 0 with the restored weights -- confirm this is intended.
            start_epoch = 27
            print('start_epoch', start_epoch)
            model_init.load_state_dict(checkpoint['state_dict'])
            model_init.alphas_normal.data = checkpoint['alphas_normal']
            model_init.alphas_reduce.data = checkpoint['alphas_reduce']
            model_init = model_init.cuda()
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(model_init))
            optimizer_init.load_state_dict(checkpoint['optimizer'])
            architect_init.optimizer.load_state_dict(
                checkpoint['arch_optimizer'])
            # The scheduler state is not part of the checkpoint, so a fresh
            # scheduler is created on top of the restored optimizer.
            scheduler_init = CosineAnnealingLR(optimizer_init,
                                               float(args.epochs),
                                               eta_min=args.learning_rate_min)
            analyser_init = Analyzer(args, model_init)
            la_tracker = utils.EVLocalAvg(args.window,
                                          args.report_freq_hessian,
                                          args.epochs)
            la_tracker.ev = checkpoint['ev']
            la_tracker.ev_local_avg = checkpoint['ev_local_avg']
            la_tracker.genotypes = checkpoint['genotypes']
            la_tracker.la_epochs = checkpoint['la_epochs']
            la_tracker.la_start_idx = checkpoint['la_start_idx']
            la_tracker.la_end_idx = checkpoint['la_end_idx']
            # NOTE(review): not read afterwards; train_epochs takes the
            # learning rate from its scheduler instead.
            lr = checkpoint['lr']

    train_queue, valid_queue, train_transform, valid_transform = helper.get_train_val_loaders(
    )

    # Per-epoch error curves; shared with (and mutated by) train_epochs.
    errors_dict = {
        'train_acc': [],
        'train_loss': [],
        'valid_acc': [],
        'valid_loss': []
    }

    def train_epochs(epochs_to_train,
                     iteration,
                     args=args,
                     model=model_init,
                     optimizer=optimizer_init,
                     scheduler=scheduler_init,
                     train_queue=train_queue,
                     valid_queue=valid_queue,
                     train_transform=train_transform,
                     valid_transform=valid_transform,
                     architect=architect_init,
                     criterion=criterion,
                     primitives=primitives,
                     analyser=analyser_init,
                     la_tracker=la_tracker,
                     errors_dict=errors_dict,
                     start_epoch=-1):
        """Train for epochs ``start_epoch + 1`` .. ``epochs_to_train - 1``.

        After each epoch the model is validated, the error curves in
        ``errors_dict`` are extended, a checkpoint is saved and an entry is
        appended to ``schedule_of_params``. When ``la_tracker.stop_search`` is
        set (presumably inside ``train``) and the weight decay has not yet
        reached ``args.max_weight_decay``, ``args.early_stop`` selects the
        reaction:

        * 1 -- stop and report the values recorded at the stop epoch;
        * 2 -- write out the early-stop architecture once, then keep training;
        * 3 -- roll back to an earlier checkpoint, multiply the weight decay
          by ``args.mul_factor`` and recurse for the remaining epochs.

        Args:
            epochs_to_train: Exclusive upper bound of the epoch loop.
            iteration: Round counter; the first call passes 1 and every
                rollback recursion passes ``iteration + 1``.
            start_epoch: Last epoch considered done; training begins at
                ``start_epoch + 1``.
            The remaining keyword arguments default to the objects built in
            the enclosing ``main`` and are overridden on recursion.

        Returns:
            ``(genotype, valid_acc)``; ``(model.genotype(), -1)`` when there
            is no epoch left to train.
        """

        logging.info('STARTING ITERATION: %d', iteration)
        logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)

        la_tracker.stop_search = False

        if epochs_to_train - start_epoch - 1 <= 0:
            return model.genotype(), -1
        for epoch in range(start_epoch + 1, epochs_to_train):
            # Position the cosine schedule explicitly so that a (re)started
            # run picks up the learning rate belonging to this epoch.
            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            if args.drop_path_prob != 0:
                # Ramp drop-path and cutout probability linearly with epoch.
                model.drop_path_prob = args.drop_path_prob * epoch / (
                    args.epochs - 1)
                train_transform.transforms[
                    -1].cutout_prob = args.cutout_prob * epoch / (args.epochs -
                                                                  1)
                logging.info('epoch %d lr %e drop_prob %e cutout_prob %e',
                             epoch, lr, model.drop_path_prob,
                             train_transform.transforms[-1].cutout_prob)
            else:
                logging.info('epoch %d lr %e', epoch, lr)

            # training
            train_acc, train_obj = train(epoch, primitives, train_queue,
                                         valid_queue, model, architect,
                                         criterion, optimizer, lr, analyser,
                                         la_tracker, iteration)
            logging.info('train_acc %f', train_acc)

            # validation
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

            # update the errors dictionary (accuracies are stored as errors)
            errors_dict['train_acc'].append(100 - train_acc)
            errors_dict['train_loss'].append(train_obj)
            errors_dict['valid_acc'].append(100 - valid_acc)
            errors_dict['valid_loss'].append(valid_obj)

            genotype = model.genotype()

            logging.info('genotype = %s', genotype)

            print(F.softmax(model.alphas_normal, dim=-1))
            print(F.softmax(model.alphas_reduce, dim=-1))

            # Everything needed to roll back to this epoch later on. The
            # scheduler state is deliberately absent: it is re-created and
            # stepped to the right epoch instead.
            state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'alphas_normal': model.alphas_normal.data,
                'alphas_reduce': model.alphas_reduce.data,
                'arch_optimizer': architect.optimizer.state_dict(),
                'lr': lr,
                'ev': la_tracker.ev,
                'ev_local_avg': la_tracker.ev_local_avg,
                'genotypes': la_tracker.genotypes,
                'la_epochs': la_tracker.la_epochs,
                'la_start_idx': la_tracker.la_start_idx,
                'la_end_idx': la_tracker.la_end_idx,
            }

            utils.save_checkpoint(state, False, args.save, epoch, args.task_id)

            if not args.compute_hessian:
                ev = -1
            else:
                ev = la_tracker.ev[-1]
            params = {
                'iteration': iteration,
                'epoch': epoch,
                'wd': args.weight_decay,
                'ev': ev,
            }

            schedule_of_params.append(params)

            # limit the number of iterations based on the maximum regularization
            # value predefined by the user: once the weight decay has grown to
            # (roughly) args.max_weight_decay the ratio of the logs rounds to
            # 1.0 and no further early-stop handling takes place.
            final_iteration = round(
                np.log(args.max_weight_decay) / np.log(args.weight_decay),
                1) == 1.

            if la_tracker.stop_search and not final_iteration:
                if args.early_stop == 1:
                    # set the following to the values they had at stop_epoch
                    errors_dict['valid_acc'] = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    genotype = la_tracker.stop_genotype
                    valid_acc = 100 - errors_dict['valid_acc'][
                        la_tracker.stop_epoch]
                    logging.info(
                        'Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('Validation accuracy at stop epoch: %f',
                                 valid_acc)
                    logging.info('Genotype at stop epoch: %s', genotype)
                    break

                elif args.early_stop == 2:
                    # simulate early stopping and continue search afterwards
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    logging.info(
                        '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('(SIM) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(SIM) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    with open(
                            os.path.join(args.save,
                                         'arch_early_{}'.format(args.task_id)),
                            'w') as file:
                        file.write(str(simulated_genotype))

                    utils.write_yaml_results(args,
                                             'early_' + args.results_file_arch,
                                             str(simulated_genotype))
                    utils.write_yaml_results(args, 'early_stop_epochs',
                                             la_tracker.stop_epoch)

                    # Record the simulated stop only once.
                    args.early_stop = 0

                elif args.early_stop == 3:
                    # adjust regularization
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    stop_epoch = la_tracker.stop_epoch
                    start_again_epoch = stop_epoch - args.extra_rollback_epochs
                    logging.info(
                        '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                        stop_epoch, epoch)
                    logging.info('(ADA) Rolling back to epoch %d',
                                 start_again_epoch)
                    logging.info(
                        '(ADA) Restoring model parameters and continuing for %d epochs',
                        epochs_to_train - start_again_epoch - 1)

                    if iteration == 1:
                        logging.info(
                            '(ADA) Saving the architecture at the early stop epoch and '
                            'continuing with the adaptive regularization strategy'
                        )
                        utils.write_yaml_results(
                            args, 'early_' + args.results_file_arch,
                            str(simulated_genotype))

                    # Drop the local references to the current objects before
                    # their replacements are built.
                    del model
                    del architect
                    del optimizer
                    del scheduler
                    del analyser

                    model_new = Network(args.init_channels,
                                        args.n_classes,
                                        args.layers,
                                        criterion,
                                        primitives,
                                        steps=args.nodes)
                    model_new = model_new.cuda()

                    optimizer_new = torch.optim.SGD(
                        model_new.parameters(),
                        args.learning_rate,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

                    architect_new = Architect(model_new, args)

                    analyser_new = Analyzer(args, model_new)

                    # NOTE(review): this restored tracker is bound to a local
                    # name only; the recursive call below does not receive it
                    # and falls back to its default tracker -- confirm whether
                    # it should be forwarded as la_tracker=la_tracker.
                    la_tracker = utils.EVLocalAvg(args.window,
                                                  args.report_freq_hessian,
                                                  args.epochs)

                    lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                               architect_new, args.save,
                                               la_tracker, start_again_epoch,
                                               args.task_id)

                    # Strengthen the L2 regularization; the restored optimizer
                    # state still carries the old value, so patch every group.
                    args.weight_decay *= args.mul_factor
                    for param_group in optimizer_new.param_groups:
                        param_group['weight_decay'] = args.weight_decay

                    scheduler_new = CosineAnnealingLR(
                        optimizer_new,
                        float(args.epochs),
                        eta_min=args.learning_rate_min)

                    logging.info('(ADA) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(ADA) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    logging.info(
                        '(ADA) Adjusting L2 regularization to the new value: %f',
                        args.weight_decay)

                    # Continue right after the restored checkpoint: the loop
                    # begins at start_epoch + 1, so passing start_again_epoch
                    # matches the "(ADA) ... continuing for %d epochs" message
                    # above instead of retraining from the caller's start.
                    genotype, valid_acc = train_epochs(
                        args.epochs,
                        iteration + 1,
                        model=model_new,
                        optimizer=optimizer_new,
                        architect=architect_new,
                        scheduler=scheduler_new,
                        analyser=analyser_new,
                        start_epoch=start_again_epoch)
                    args.early_stop = 0
                    break

        return genotype, valid_acc

    # call train_epochs recursively
    genotype, valid_acc = train_epochs(args.epochs, 1)

    with codecs.open(os.path.join(args.save,
                                  'errors_{}.json'.format(args.task_id)),
                     'w',
                     encoding='utf-8') as file:
        json.dump(errors_dict, file, separators=(',', ':'))

    with open(os.path.join(args.save, 'arch_{}'.format(args.task_id)),
              'w') as file:
        file.write(str(genotype))

    utils.write_yaml_results(args, args.results_file_arch, str(genotype))
    utils.write_yaml_results(args, args.results_file_perf, 100 - valid_acc)

    # Opened in append mode: repeated runs add further pickles to the file.
    with open(
            os.path.join(args.save, 'schedule_{}.pickle'.format(args.task_id)),
            'ab') as file:
        pickle.dump(schedule_of_params, file, pickle.HIGHEST_PROTOCOL)
示例#2
0
    def train_epochs(epochs_to_train,
                     iteration,
                     args=args,
                     model=model_init,
                     optimizer=optimizer_init,
                     scheduler=scheduler_init,
                     train_queue=train_queue,
                     valid_queue=valid_queue,
                     train_transform=train_transform,
                     valid_transform=valid_transform,
                     architect=architect_init,
                     criterion=criterion,
                     primitives=primitives,
                     analyser=analyser_init,
                     la_tracker=la_tracker,
                     errors_dict=errors_dict,
                     start_epoch=-1):
        """Train for epochs ``start_epoch + 1`` .. ``epochs_to_train - 1``.

        After each epoch the model is validated, the error curves in
        ``errors_dict`` are extended, a checkpoint is saved and an entry is
        appended to ``schedule_of_params``. When ``la_tracker.stop_search`` is
        set (presumably inside ``train``) and the weight decay has not yet
        reached ``args.max_weight_decay``, ``args.early_stop`` selects the
        reaction:

        * 1 -- stop and report the values recorded at the stop epoch;
        * 2 -- write out the early-stop architecture once, then keep training;
        * 3 -- roll back to an earlier checkpoint, multiply the weight decay
          by ``args.mul_factor`` and recurse for the remaining epochs.

        Args:
            epochs_to_train: Exclusive upper bound of the epoch loop.
            iteration: Round counter; every rollback recursion passes
                ``iteration + 1``.
            start_epoch: Last epoch considered done; training begins at
                ``start_epoch + 1``.
            The remaining keyword arguments default to objects from the
            enclosing scope and are overridden on recursion.

        Returns:
            ``(genotype, valid_acc)``; ``(model.genotype(), -1)`` when there
            is no epoch left to train.
        """

        logging.info('STARTING ITERATION: %d', iteration)
        logging.info('EPOCHS TO TRAIN: %d', epochs_to_train - start_epoch - 1)

        la_tracker.stop_search = False

        if epochs_to_train - start_epoch - 1 <= 0:
            return model.genotype(), -1
        for epoch in range(start_epoch + 1, epochs_to_train):
            # Position the cosine schedule explicitly so that a (re)started
            # run picks up the learning rate belonging to this epoch.
            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            if args.drop_path_prob != 0:
                # Ramp drop-path and cutout probability linearly with epoch.
                model.drop_path_prob = args.drop_path_prob * epoch / (
                    args.epochs - 1)
                train_transform.transforms[
                    -1].cutout_prob = args.cutout_prob * epoch / (args.epochs -
                                                                  1)
                logging.info('epoch %d lr %e drop_prob %e cutout_prob %e',
                             epoch, lr, model.drop_path_prob,
                             train_transform.transforms[-1].cutout_prob)
            else:
                logging.info('epoch %d lr %e', epoch, lr)

            # training
            train_acc, train_obj = train(epoch, primitives, train_queue,
                                         valid_queue, model, architect,
                                         criterion, optimizer, lr, analyser,
                                         la_tracker, iteration)
            logging.info('train_acc %f', train_acc)

            # validation
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

            # update the errors dictionary (accuracies are stored as errors)
            errors_dict['train_acc'].append(100 - train_acc)
            errors_dict['train_loss'].append(train_obj)
            errors_dict['valid_acc'].append(100 - valid_acc)
            errors_dict['valid_loss'].append(valid_obj)

            genotype = model.genotype()

            logging.info('genotype = %s', genotype)

            print(F.softmax(model.alphas_normal, dim=-1))
            print(F.softmax(model.alphas_reduce, dim=-1))

            # Everything needed to roll back to this epoch later on. The
            # scheduler state is deliberately absent: it is re-created and
            # stepped to the right epoch instead.
            state = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'alphas_normal': model.alphas_normal.data,
                'alphas_reduce': model.alphas_reduce.data,
                'arch_optimizer': architect.optimizer.state_dict(),
                'lr': lr,
                'ev': la_tracker.ev,
                'ev_local_avg': la_tracker.ev_local_avg,
                'genotypes': la_tracker.genotypes,
                'la_epochs': la_tracker.la_epochs,
                'la_start_idx': la_tracker.la_start_idx,
                'la_end_idx': la_tracker.la_end_idx,
            }

            utils.save_checkpoint(state, False, args.save, epoch, args.task_id)

            if not args.compute_hessian:
                ev = -1
            else:
                ev = la_tracker.ev[-1]
            params = {
                'iteration': iteration,
                'epoch': epoch,
                'wd': args.weight_decay,
                'ev': ev,
            }

            schedule_of_params.append(params)

            # limit the number of iterations based on the maximum regularization
            # value predefined by the user: once the weight decay has grown to
            # (roughly) args.max_weight_decay the ratio of the logs rounds to
            # 1.0 and no further early-stop handling takes place.
            final_iteration = round(
                np.log(args.max_weight_decay) / np.log(args.weight_decay),
                1) == 1.

            if la_tracker.stop_search and not final_iteration:
                if args.early_stop == 1:
                    # set the following to the values they had at stop_epoch
                    errors_dict['valid_acc'] = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    genotype = la_tracker.stop_genotype
                    valid_acc = 100 - errors_dict['valid_acc'][
                        la_tracker.stop_epoch]
                    logging.info(
                        'Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('Validation accuracy at stop epoch: %f',
                                 valid_acc)
                    logging.info('Genotype at stop epoch: %s', genotype)
                    break

                elif args.early_stop == 2:
                    # simulate early stopping and continue search afterwards
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    logging.info(
                        '(SIM) Decided to stop the search at epoch %d (Current epoch: %d)',
                        la_tracker.stop_epoch, epoch)
                    logging.info('(SIM) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(SIM) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    with open(
                            os.path.join(args.save,
                                         'arch_early_{}'.format(args.task_id)),
                            'w') as file:
                        file.write(str(simulated_genotype))

                    utils.write_yaml_results(args,
                                             'early_' + args.results_file_arch,
                                             str(simulated_genotype))
                    utils.write_yaml_results(args, 'early_stop_epochs',
                                             la_tracker.stop_epoch)

                    # Record the simulated stop only once.
                    args.early_stop = 0

                elif args.early_stop == 3:
                    # adjust regularization
                    simulated_errors_dict = errors_dict[
                        'valid_acc'][:la_tracker.stop_epoch + 1]
                    simulated_genotype = la_tracker.stop_genotype
                    simulated_valid_acc = 100 - simulated_errors_dict[
                        la_tracker.stop_epoch]
                    stop_epoch = la_tracker.stop_epoch
                    start_again_epoch = stop_epoch - args.extra_rollback_epochs
                    logging.info(
                        '(ADA) Decided to increase regularization at epoch %d (Current epoch: %d)',
                        stop_epoch, epoch)
                    logging.info('(ADA) Rolling back to epoch %d',
                                 start_again_epoch)
                    logging.info(
                        '(ADA) Restoring model parameters and continuing for %d epochs',
                        epochs_to_train - start_again_epoch - 1)

                    if iteration == 1:
                        logging.info(
                            '(ADA) Saving the architecture at the early stop epoch and '
                            'continuing with the adaptive regularization strategy'
                        )
                        utils.write_yaml_results(
                            args, 'early_' + args.results_file_arch,
                            str(simulated_genotype))

                    # Drop the local references to the current objects before
                    # their replacements are built.
                    del model
                    del architect
                    del optimizer
                    del scheduler
                    del analyser

                    model_new = Network(args.init_channels,
                                        args.n_classes,
                                        args.layers,
                                        criterion,
                                        primitives,
                                        steps=args.nodes)
                    model_new = model_new.cuda()

                    optimizer_new = torch.optim.SGD(
                        model_new.parameters(),
                        args.learning_rate,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

                    architect_new = Architect(model_new, args)

                    analyser_new = Analyzer(args, model_new)

                    # NOTE(review): this restored tracker is bound to a local
                    # name only; the recursive call below does not receive it
                    # and falls back to its default tracker -- confirm whether
                    # it should be forwarded as la_tracker=la_tracker.
                    la_tracker = utils.EVLocalAvg(args.window,
                                                  args.report_freq_hessian,
                                                  args.epochs)

                    lr = utils.load_checkpoint(model_new, optimizer_new, None,
                                               architect_new, args.save,
                                               la_tracker, start_again_epoch,
                                               args.task_id)

                    # Strengthen the L2 regularization; the restored optimizer
                    # state still carries the old value, so patch every group.
                    args.weight_decay *= args.mul_factor
                    for param_group in optimizer_new.param_groups:
                        param_group['weight_decay'] = args.weight_decay

                    scheduler_new = CosineAnnealingLR(
                        optimizer_new,
                        float(args.epochs),
                        eta_min=args.learning_rate_min)

                    logging.info('(ADA) Validation accuracy at stop epoch: %f',
                                 simulated_valid_acc)
                    logging.info('(ADA) Genotype at stop epoch: %s',
                                 simulated_genotype)

                    logging.info(
                        '(ADA) Adjusting L2 regularization to the new value: %f',
                        args.weight_decay)

                    # Continue right after the restored checkpoint: the loop
                    # begins at start_epoch + 1, so passing start_again_epoch
                    # matches the "(ADA) ... continuing for %d epochs" message
                    # above instead of retraining from the caller's start.
                    genotype, valid_acc = train_epochs(
                        args.epochs,
                        iteration + 1,
                        model=model_new,
                        optimizer=optimizer_new,
                        architect=architect_new,
                        scheduler=scheduler_new,
                        analyser=analyser_new,
                        start_epoch=start_again_epoch)
                    args.early_stop = 0
                    break

        return genotype, valid_acc