예제 #1
0
def get_all_trained_models_info(models_path, use_profiler=False, device='gpu'):
    """Print a summary of every trained model found under ``models_path``.

    Each entry of ``models_path`` (visited in sorted order) is treated as a
    model name. Its saved parameters are loaded with ``arcs.load_params`` and
    the name, task, network type, the last recorded top-1/top-5 train and
    test accuracies, the training time and the epoch count are printed.

    Args:
        models_path: Directory whose entries are model names.
        use_profiler: When true, the model itself is also loaded, moved to
            ``device`` and profiled; its operation and parameter counts are
            printed.
        device: Device identifier handed to ``model.to`` and to the profiler.
            NOTE(review): if ``model`` is a PyTorch module, the default
            ``'gpu'`` is not an accepted device string (PyTorch uses
            ``'cuda'``), so profiling with the default would fail for every
            model - confirm and pass an explicit device.

    Returns:
        None. All results are written to standard output.

    A model that cannot be processed is reported as ``FAIL: <name>`` followed
    by the reason, and the scan continues with the next model.
    """
    print('Testing all models in: {}'.format(models_path))

    for model_name in sorted(os.listdir(models_path)):
        try:
            model_params = arcs.load_params(models_path, model_name, -1)
            train_time = model_params['total_time']
            num_epochs = model_params['epochs']
            architecture = model_params['architecture']
            print(model_name)
            task = model_params['task']
            print(task)
            net_type = model_params['network_type']
            print(net_type)

            top1_test = model_params['test_top1_acc']
            top1_train = model_params['train_top1_acc']
            top5_test = model_params['test_top5_acc']
            top5_train = model_params['train_top5_acc']

            # The accuracy entries are sequences; only the last value of each
            # is reported.
            print('Top1 Test accuracy: {}'.format(top1_test[-1]))
            print('Top5 Test accuracy: {}'.format(top5_test[-1]))
            print('\nTop1 Train accuracy: {}'.format(top1_train[-1]))
            print('Top5 Train accuracy: {}'.format(top5_train[-1]))

            print('Training time: {}, in {} epochs'.format(
                train_time, num_epochs))

            if use_profiler:
                model, _ = arcs.load_model(models_path, model_name, epoch=-1)
                model.to(device)
                input_size = model_params['input_size']

                if architecture == 'dsn':
                    # The 'dsn' results are printed as returned, without the
                    # unit scaling applied in the other branch.
                    total_ops, total_params = profile_dsn(
                        model, input_size, device)
                    print("#Ops (GOps): {}".format(total_ops))
                    print("#Params (mil): {}".format(total_params))

                else:
                    total_ops, total_params = profile(model, input_size,
                                                      device)
                    print("#Ops: %f GOps" % (total_ops / 1e9))
                    print("#Parameters: %f M" % (total_params / 1e6))

            print('------------------------')
        except Exception as exc:
            # Best-effort scan: one broken model must not abort the listing.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate so the scan can still be interrupted.
            print('FAIL: {}'.format(model_name))
            print('  reason: {!r}'.format(exc))
예제 #2
0
def train_sdns(models_path, networks, ic_only=False, device='cpu'):
    """Build an SDN from the matching CNN for each name, then train them.

    For every name in ``networks`` the CNN name is derived by replacing
    ``'sdn'`` with ``'cnn'``. The network parameters are rebuilt with
    ``arcs.get_net_params`` from the stored ``network_type`` and ``task``,
    the CNN is loaded and converted by ``af.cnn_to_sdn``, and the result is
    saved under the SDN name as epoch 0. ``train`` is then called once for
    the whole list with ``sdn=True``.

    Args:
        models_path: Location of the stored models.
        networks: Iterable of SDN model names.
        ic_only: When true only the ICs are trained, so a pre-trained CNN is
            loaded; otherwise an untrained CNN is loaded.
        device: Device identifier forwarded to ``train``.
    """
    # -1 selects the pre-trained CNN (IC-only training); 0 selects the
    # untrained CNN (ICs and the original network are trained together).
    load_epoch = -1 if ic_only else 0

    for sdn_name in networks:
        base_cnn_name = sdn_name.replace('sdn', 'cnn')

        stored_params = arcs.load_params(models_path, sdn_name)
        net_type = stored_params['network_type']
        net_task = stored_params['task']
        fresh_params = arcs.get_net_params(net_type, net_task)

        # Load the CNN and convert it to a SDN.
        converted, _ = af.cnn_to_sdn(
            models_path, base_cnn_name, fresh_params, load_epoch)

        # Save the resulting SDN.
        arcs.save_model(converted, fresh_params, models_path, sdn_name,
                        epoch=0)

    train(models_path, networks, sdn=True, ic_only_sdn=ic_only, device=device)