def init_dataset():
    """Build and return ``(train_dataset, test_dataset)`` for the configured task.

    For ``args.task == "detection"`` the (expensive) VOC ROI databases are built
    once and cached to ``args.data_cache`` with pickle; subsequent runs load the
    cache instead. For every other task the dataset pair is selected by
    ``args.data_set``.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.

    Exits the process with status -1 when ``args.data_set`` is unsupported.
    """
    if args.task == "detection":
        if not os.path.exists(args.data_cache):
            imdb_name = "voc_2007_trainval"
            imdbval_name = "voc_2007_test"

            imdb, roidb, ratio_list, ratio_index = combined_roidb(
                imdb_name, ['DATA_DIR', args.data_dir], sizes=args.train_size_file)
            train_dataset = roibatchLoader(
                roidb, ratio_list, ratio_index, args.batch_size,
                imdb.num_classes, imdb._image_index_temp, training=True)

            imdb_, roidb_, ratio_list_, ratio_index_ = combined_roidb(
                imdbval_name, ['DATA_DIR', args.data_dir],
                sizes=args.test_size_file, training=False)
            imdb_.competition_mode(on=True)
            # Evaluation loader: batch size 1, no normalization.
            test_dataset = roibatchLoader(
                roidb_, ratio_list_, ratio_index_, 1, imdb_.num_classes,
                imdb_._image_index_temp, training=False, normalize=False)

            with open(args.data_cache, 'wb') as f:
                pickle.dump(train_dataset, f, -1)
                pickle.dump(test_dataset, f, -1)
        else:
            # SECURITY NOTE: pickle.load executes arbitrary code from the file;
            # only read cache files from trusted, locally-written paths.
            with open(args.data_cache, 'rb') as f:
                train_dataset = pickle.load(f)
                test_dataset = pickle.load(f)
    else:
        if args.data_set == 'Mnist':
            train_transform, test_transform = get_data_transform('mnist')
            train_dataset = datasets.MNIST(args.data_dir, train=True,
                                           download=True, transform=train_transform)
            test_dataset = datasets.MNIST(args.data_dir, train=False,
                                          download=True, transform=test_transform)
        elif args.data_set == 'cifar10':
            train_transform, test_transform = get_data_transform('cifar')
            train_dataset = datasets.CIFAR10(args.data_dir, train=True,
                                             download=True, transform=train_transform)
            test_dataset = datasets.CIFAR10(args.data_dir, train=False,
                                            download=True, transform=test_transform)
        elif args.data_set == "imagenet":
            train_transform, test_transform = get_data_transform('imagenet')
            train_dataset = datasets.ImageNet(args.data_dir, split='train',
                                              download=False, transform=train_transform)
            test_dataset = datasets.ImageNet(args.data_dir, split='val',
                                             download=False, transform=test_transform)
        elif args.data_set == 'emnist':
            test_dataset = datasets.EMNIST(args.data_dir, split='balanced',
                                           train=False, download=True,
                                           transform=transforms.ToTensor())
            train_dataset = datasets.EMNIST(args.data_dir, split='balanced',
                                            train=True, download=True,
                                            transform=transforms.ToTensor())
        elif args.data_set == 'femnist':
            from utils.femnist import FEMNIST
            train_transform, test_transform = get_data_transform('mnist')
            train_dataset = FEMNIST(args.data_dir, train=True, transform=train_transform)
            test_dataset = FEMNIST(args.data_dir, train=False, transform=test_transform)
        elif args.data_set == 'openImg':
            from utils.openimage import OpenImage
            train_transform, test_transform = get_data_transform('openImg')
            train_dataset = OpenImage(args.data_dir, dataset='train',
                                      transform=train_transform)
            test_dataset = OpenImage(args.data_dir, dataset='test',
                                     transform=test_transform)
        elif args.data_set == 'blog':
            train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
            test_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
        elif args.data_set == 'stackoverflow':
            from utils.stackoverflow import stackoverflow
            train_dataset = stackoverflow(args.data_dir, train=True)
            test_dataset = stackoverflow(args.data_dir, train=False)
        elif args.data_set == 'yelp':
            import utils.dataloaders as fl_loader
            train_dataset = fl_loader.TextSentimentDataset(
                args.data_dir, train=True, tokenizer=tokenizer,
                max_len=args.clf_block_size)
            test_dataset = fl_loader.TextSentimentDataset(
                args.data_dir, train=False, tokenizer=tokenizer,
                max_len=args.clf_block_size)
        elif args.data_set == 'google_speech':
            bkg = '_background_noise_'
            # Training-time augmentation pipeline operating on the STFT domain.
            data_aug_transform = transforms.Compose([
                ChangeAmplitude(), ChangeSpeedAndPitchAudio(), FixAudioLength(),
                ToSTFT(), StretchAudioOnSTFT(), TimeshiftAudioOnSTFT(),
                FixSTFTDimension()
            ])
            bg_dataset = BackgroundNoiseDataset(
                os.path.join(args.data_dir, bkg), data_aug_transform)
            add_bg_noise = AddBackgroundNoiseOnSTFT(bg_dataset)
            train_feature_transform = transforms.Compose([
                ToMelSpectrogramFromSTFT(n_mels=32), DeleteSTFT(),
                ToTensor('mel_spectrogram', 'input')
            ])
            train_dataset = SPEECH(args.data_dir, dataset='train',
                                   transform=transforms.Compose([
                                       LoadAudio(), data_aug_transform,
                                       add_bg_noise, train_feature_transform
                                   ]))
            valid_feature_transform = transforms.Compose([
                ToMelSpectrogram(n_mels=32), ToTensor('mel_spectrogram', 'input')
            ])
            test_dataset = SPEECH(args.data_dir, dataset='test',
                                  transform=transforms.Compose([
                                      LoadAudio(), FixAudioLength(),
                                      valid_feature_transform
                                  ]))
        elif args.data_set == 'common_voice':
            from utils.voice_data_loader import SpectrogramDataset
            train_dataset = SpectrogramDataset(
                audio_conf=model.audio_conf,
                manifest_filepath=args.train_manifest,
                labels=model.labels,
                normalize=True,
                speed_volume_perturb=args.speed_volume_perturb,
                spec_augment=args.spec_augment,
                data_mapfile=args.data_mapfile)
            test_dataset = SpectrogramDataset(
                audio_conf=model.audio_conf,
                manifest_filepath=args.test_manifest,
                labels=model.labels,
                normalize=True,
                speed_volume_perturb=False,
                spec_augment=False)
        else:
            # Bug fix: the old message advertised 'Cifar' and 'speech', which no
            # branch above accepts, and omitted several values that ARE handled.
            # List exactly the supported data_set values.
            supported = ['Mnist', 'cifar10', 'imagenet', 'emnist', 'femnist',
                         'openImg', 'blog', 'stackoverflow', 'yelp',
                         'google_speech', 'common_voice']
            print('DataSet must be {}!'.format(supported))
            sys.exit(-1)

    return train_dataset, test_dataset
# NOTE(review): the four statements below reference names (`rank`, `size`,
# `backend`, `fn`, `model`, `test_data`, `queue`, `param_q`, `stop_signal`)
# that are not defined at this level — this looks like the tail of a
# process-bootstrap helper whose `def` line is outside this chunk. Confirm
# the enclosing function against the full file.
os.environ['MASTER_ADDR'] = args.ps_ip  # parameter-server address for torch.distributed rendezvous
os.environ['MASTER_PORT'] = args.ps_port
dist.init_process_group(backend, rank=rank, world_size=size)
# Hand control to the caller-supplied worker/server loop.
fn(model, test_data, queue, param_q, stop_signal)

if __name__ == "__main__":
    # Random seed setup: draw one seed and feed it to both the stdlib and
    # torch RNGs so the run is reproducible once the seed is known.
    manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    # Select the model and its matching test set; only two models are
    # supported by this entry point.
    if args.model == 'MnistCNN':
        model = MnistCNN()
        train_t, test_t = get_data_transform('mnist')
        test_dataset = datasets.MNIST(args.data_dir, train=False,
                                      download=False, transform=test_t)
    elif args.model == 'AlexNet':
        model = AlexNetForCIFAR()
        train_t, test_t = get_data_transform('cifar')
        test_dataset = datasets.CIFAR10(args.data_dir, train=False,
                                        download=False, transform=test_t)
    else:
        print('Model must be {} or {}!'.format('MnistCNN', 'AlexNet'))
        sys.exit(-1)
def init_dataset():
    """Initialize the model for ``args.task`` and, on client ranks, the datasets.

    The coordinator (``args.this_rank == 0``) gets empty lists for both
    datasets; only client ranks load data.

    Returns:
        tuple: ``(model, train_dataset, test_dataset)``.

    Exits the process with status -1 on an unsupported model/data_set or a
    failed checkpoint load.
    """
    global tokenizer

    # Number of output classes per dataset; also sizes the speech models.
    outputClass = {'Mnist': 10, 'cifar10': 10, "imagenet": 1000, 'emnist': 47,
                   'openImg': 596, 'google_speech': 35, 'femnist': 62, 'yelp': 5}

    logging.info("====Initialize the model")

    if args.task == 'nlp':
        # Train from scratch: build the model from a local config file rather
        # than downloading pretrained weights.
        config = AutoConfig.from_pretrained(
            os.path.join(args.data_dir, 'albert-base-v2-config.json'))
        model = AutoModelWithLMHead.from_config(config)
    elif args.task == 'text_clf':
        config = AutoConfig.from_pretrained(
            os.path.join(args.data_dir, 'albert-base-v2-config.json'))
        config.num_labels = outputClass[args.data_set]
        # config.output_attentions = False
        # config.output_hidden_states = False
        from transformers import AlbertForSequenceClassification
        model = AlbertForSequenceClassification(config)
    elif args.task == 'tag-one-sample':
        # Logistic-regression model for tag prediction.
        model = LogisticRegression(args.vocab_token_size, args.vocab_tag_size)
    elif args.task == 'speech':
        # Single-channel (audio) variants of the standard vision backbones.
        if args.model == 'mobilenet':
            from utils.resnet_speech import mobilenet_v2
            model = mobilenet_v2(num_classes=outputClass[args.data_set], inchannels=1)
        elif args.model == "resnet18":
            from utils.resnet_speech import resnet18
            model = resnet18(num_classes=outputClass[args.data_set], in_channels=1)
        elif args.model == "resnet34":
            from utils.resnet_speech import resnet34
            model = resnet34(num_classes=outputClass[args.data_set], in_channels=1)
        elif args.model == "resnet50":
            from utils.resnet_speech import resnet50
            model = resnet50(num_classes=outputClass[args.data_set], in_channels=1)
        elif args.model == "resnet101":
            from utils.resnet_speech import resnet101
            model = resnet101(num_classes=outputClass[args.data_set], in_channels=1)
        elif args.model == "resnet152":
            from utils.resnet_speech import resnet152
            model = resnet152(num_classes=outputClass[args.data_set], in_channels=1)
        else:
            # Should not reach here
            logging.info('Model must be resnet or mobilenet')
            sys.exit(-1)
    elif args.task == 'voice':
        from utils.voice_model import DeepSpeech, supported_rnns
        # Initialise a new model for training from the label set and audio config.
        with open(args.labels_path) as label_file:
            labels = json.load(label_file)
        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))
        model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                           nb_layers=args.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[args.rnn_type.lower()],
                           audio_conf=audio_conf,
                           bidirectional=args.bidirectional)
    else:
        # Fall back to a torchvision model looked up by name.
        model = tormodels.__dict__[args.model](num_classes=outputClass[args.data_set])

    if args.load_model:
        try:
            # SECURITY NOTE: pickle.load executes arbitrary code from the file;
            # only load checkpoints (modelPath is module-level) from trusted paths.
            with open(modelPath, 'rb') as fin:
                model = pickle.load(fin)
            logging.info("====Load model successfully\n")
        except Exception as e:
            logging.info("====Error: Failed to load model due to {}\n".format(str(e)))
            sys.exit(-1)

    train_dataset, test_dataset = [], []
    # Load data if the machine acts as clients
    if args.this_rank != 0:
        if args.data_set == 'Mnist':
            train_transform, test_transform = get_data_transform('mnist')
            train_dataset = datasets.MNIST(args.data_dir, train=True,
                                           download=True, transform=train_transform)
            test_dataset = datasets.MNIST(args.data_dir, train=False,
                                          download=True, transform=test_transform)
        elif args.data_set == 'cifar10':
            train_transform, test_transform = get_data_transform('cifar')
            train_dataset = datasets.CIFAR10(args.data_dir, train=True,
                                             download=True, transform=train_transform)
            test_dataset = datasets.CIFAR10(args.data_dir, train=False,
                                            download=True, transform=test_transform)
        elif args.data_set == "imagenet":
            train_transform, test_transform = get_data_transform('imagenet')
            train_dataset = datasets.ImageNet(args.data_dir, split='train',
                                              download=False, transform=train_transform)
            test_dataset = datasets.ImageNet(args.data_dir, split='val',
                                             download=False, transform=test_transform)
        elif args.data_set == 'emnist':
            test_dataset = datasets.EMNIST(args.data_dir, split='balanced',
                                           train=False, download=True,
                                           transform=transforms.ToTensor())
            train_dataset = datasets.EMNIST(args.data_dir, split='balanced',
                                            train=True, download=True,
                                            transform=transforms.ToTensor())
        elif args.data_set == 'femnist':
            from utils.femnist import FEMNIST
            train_transform, test_transform = get_data_transform('mnist')
            train_dataset = FEMNIST(args.data_dir, train=True, transform=train_transform)
            test_dataset = FEMNIST(args.data_dir, train=False, transform=test_transform)
        elif args.data_set == 'openImg':
            from utils.openImg import OPENIMG
            # inception_v3 needs 299x299 inputs, hence a dedicated transform set.
            transformer_ns = 'openImg' if args.model != 'inception_v3' else 'openImgInception'
            train_transform, test_transform = get_data_transform(transformer_ns)
            train_dataset = OPENIMG(args.data_dir, train=True, transform=train_transform)
            test_dataset = OPENIMG(args.data_dir, train=False, transform=test_transform)
        elif args.data_set == 'blog':
            train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
            test_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
        elif args.data_set == 'stackoverflow':
            from utils.stackoverflow import stackoverflow
            train_dataset = stackoverflow(args.data_dir, train=True)
            test_dataset = stackoverflow(args.data_dir, train=False)
        elif args.data_set == 'yelp':
            import utils.dataloaders as fl_loader
            train_dataset = fl_loader.TextSentimentDataset(
                args.data_dir, train=True, tokenizer=tokenizer,
                max_len=args.clf_block_size)
            test_dataset = fl_loader.TextSentimentDataset(
                args.data_dir, train=False, tokenizer=tokenizer,
                max_len=args.clf_block_size)
        elif args.data_set == 'google_speech':
            bkg = '_background_noise_'
            # Training-time augmentation pipeline operating on the STFT domain.
            data_aug_transform = transforms.Compose([
                ChangeAmplitude(), ChangeSpeedAndPitchAudio(), FixAudioLength(),
                ToSTFT(), StretchAudioOnSTFT(), TimeshiftAudioOnSTFT(),
                FixSTFTDimension()
            ])
            bg_dataset = BackgroundNoiseDataset(
                os.path.join(args.data_dir, bkg), data_aug_transform)
            add_bg_noise = AddBackgroundNoiseOnSTFT(bg_dataset)
            train_feature_transform = transforms.Compose([
                ToMelSpectrogramFromSTFT(n_mels=32), DeleteSTFT(),
                ToTensor('mel_spectrogram', 'input')
            ])
            train_dataset = SPEECH(args.data_dir, train=True,
                                   transform=transforms.Compose([
                                       LoadAudio(), data_aug_transform,
                                       add_bg_noise, train_feature_transform
                                   ]))
            valid_feature_transform = transforms.Compose([
                ToMelSpectrogram(n_mels=32), ToTensor('mel_spectrogram', 'input')
            ])
            test_dataset = SPEECH(args.data_dir, train=False,
                                  transform=transforms.Compose([
                                      LoadAudio(), FixAudioLength(),
                                      valid_feature_transform
                                  ]))
        elif args.data_set == 'common_voice':
            from utils.voice_data_loader import SpectrogramDataset
            train_dataset = SpectrogramDataset(
                audio_conf=model.audio_conf,
                manifest_filepath=args.train_manifest,
                labels=model.labels,
                normalize=True,
                speed_volume_perturb=args.speed_volume_perturb,
                spec_augment=args.spec_augment,
                data_mapfile=args.data_mapfile)
            test_dataset = SpectrogramDataset(
                audio_conf=model.audio_conf,
                manifest_filepath=args.test_manifest,
                labels=model.labels,
                normalize=True,
                speed_volume_perturb=False,
                spec_augment=False)
        else:
            # Bug fix: the old message advertised 'Cifar' and 'speech', which no
            # branch above accepts, and omitted several values that ARE handled.
            # List exactly the supported data_set values.
            supported = ['Mnist', 'cifar10', 'imagenet', 'emnist', 'femnist',
                         'openImg', 'blog', 'stackoverflow', 'yelp',
                         'google_speech', 'common_voice']
            print('DataSet must be {}!'.format(supported))
            sys.exit(-1)

    return model, train_dataset, test_dataset