def train_autoencoder(device, args):
    """Train the FeatureExtractor autoencoder on audio chunk files.

    Every chunk file under the training directory is pooled regardless of
    its class label (the target is the input itself, i.e. a generative
    dataset), split into train/eval sets, trained for ``args.n_epochs``
    with a soft-DTW reconstruction loss, and the final weights are saved
    into the active wandb run directory.
    """
    # Model setup.
    model = FeatureExtractor()
    model.to(device)

    # Pool all chunk files; class labels are irrelevant here since the
    # autoencoder reconstructs its own input.
    chunk_paths = []
    for label_dir in filesystem.listdir_complete(filesystem.train_audio_chunks_dir):
        chunk_paths = chunk_paths + filesystem.listdir_complete(label_dir)
    train_chunks, eval_chunks = train_test_split(chunk_paths,
                                                 test_size=args.eval_size)

    # Both splits share the same normalization transform.
    transform = normalize
    train_dataset = GenerativeDataset(train_chunks, transforms=transform)
    eval_dataset = GenerativeDataset(eval_chunks, transforms=transform)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4,
                                  collate_fn=None, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=1,
                                 shuffle=True, num_workers=4,
                                 collate_fn=None, pin_memory=True)

    # Optimization: Adam with a soft-DTW criterion.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    loss_criterion = SoftDTW(use_cuda=True, gamma=0.1)

    # Running step counters are threaded through train/eval steps
    # (presumably for logging; confirm against train_step/eval_step).
    train_count = 0
    eval_count = 0
    for epoch in range(args.n_epochs):
        print('Epoch:', epoch, '/', args.n_epochs)
        train_count = train_step(model, train_dataloader, optimizer,
                                 loss_criterion, args.verbose_epochs,
                                 device, train_count)
        eval_count = eval_step(model, eval_dataloader, loss_criterion,
                               args.verbose_epochs, device, eval_count)

    torch.save(model.state_dict(),
               os.path.join(wandb.run.dir, 'model_checkpoint.pt'))
class Trainer(object):
    """Trains a FeatureExtractor to align a source and a target domain.

    For each batch, the loss is the negative log of the softmax probability
    assigned to the positive target sample among ``NEG_NUM`` negatives,
    where logits are cosine similarities scaled by ``gamma`` (an
    InfoNCE-style objective).
    """

    def __init__(self, src_domain, tgt_domain):
        self.num_epoch = 10
        # Temperature applied to similarity logits before the softmax.
        self.gamma = 1.0
        print('construct dataset and dataloader...')
        train_dataset = TrainDataset(src_domain, tgt_domain)
        self.NEG_NUM = train_dataset.NEG_NUM
        self.input_dim = train_dataset.sample_dim
        self.train_loader = DataLoader(train_dataset, batch_size=32)
        print('Done!')
        self.feature_extractor = FeatureExtractor(self.input_dim)
        self.optimizer = optim.SGD(self.feature_extractor.parameters(),
                                   lr=0.1, momentum=0.9)

    def train(self):
        """Run the full training schedule."""
        for epoch in range(self.num_epoch):
            self.train_one_epoch(epoch)

    def train_one_epoch(self, epoch_ind):
        """Run one pass over the loader, printing the per-batch loss.

        FIX: the loop index was named ``iter``, shadowing the builtin;
        renamed to ``batch_idx``. Dead commented-out code removed.
        """
        loss_item = 0
        for batch_idx, (src_pos, tgt_pos, tgt_negs) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            src_pos_feature = self.feature_extractor(src_pos)
            tgt_pos_feature = self.feature_extractor(tgt_pos)
            # Negatives arrive as (batch, NEG_NUM, input_dim); flatten so the
            # extractor sees one sample per row, then restore the negative axis.
            tgt_negs_features = self.feature_extractor(
                tgt_negs.reshape(-1, self.input_dim))
            feature_dim = src_pos_feature.size()[1]
            tgt_negs_features = tgt_negs_features.reshape(
                -1, self.NEG_NUM, feature_dim)

            pos_sim = cosine_similarity(src_pos_feature, tgt_pos_feature)
            src_repeated_feature = src_pos_feature.unsqueeze(1).repeat(
                1, self.NEG_NUM, 1)
            neg_sims = cosine_similarity(src_repeated_feature,
                                         tgt_negs_features, dim=2)
            # Column 0 is the positive pair; remaining columns are negatives.
            all_sims = torch.cat((pos_sim.unsqueeze(1), neg_sims), dim=1)
            PDQ = softmax(all_sims * self.gamma, dim=1)
            # Negative log-likelihood of the positive (first) column.
            loss = -PDQ[:, 0].log().mean()
            loss.backward()
            self.optimizer.step()
            loss_item += loss.item()
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch_ind, batch_idx, len(self.train_loader), loss.item()))
batch_size=1024, shuffle=True, drop_last=True) # -------------------------------------- Training Stage ------------------------------------------- # precision = 1e-8 feature_extractor = FeatureExtractor().cuda() label_predictor = LabelPredictor().cuda() domain_classifier = DomainClassifier().cuda() class_criterion = nn.CrossEntropyLoss() domain_criterion = nn.CrossEntropyLoss() optimizer_F = optim.Adam(feature_extractor.parameters()) optimizer_C = optim.Adam(label_predictor.parameters()) optimizer_D = optim.Adam(label_predictor.parameters()) scheduler_F = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_F, mode='min', factor=0.1, patience=8, verbose=True, eps=precision) scheduler_C = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_C, mode='min', factor=0.1, patience=8, verbose=True, eps=precision)
def main(args):
    """Entry point: build SRGAN models and loaders, then train/evaluate.

    Hyperparameters come from ``config.yaml``. Device placement honors
    ``config.distributed`` / ``config.gpu``; with neither set, returns
    early and does nothing.
    """
    # Fixed seeds for reproducibility.
    np.random.seed(0)
    torch.manual_seed(0)

    with open('config.yaml', 'r') as file:
        stream = file.read()
    config_dict = yaml.safe_load(stream)
    config = mapper(**config_dict)

    disc_model = Discriminator(input_shape=(config.data.channels,
                                            config.data.hr_height,
                                            config.data.hr_width))
    gen_model = GeneratorResNet()
    feature_extractor_model = FeatureExtractor()

    plt.ion()

    if config.distributed:
        disc_model.to(device)
        disc_model = nn.parallel.DistributedDataParallel(disc_model)
        gen_model.to(device)
        gen_model = nn.parallel.DistributedDataParallel(gen_model)
        feature_extractor_model.to(device)
        feature_extractor_model = nn.parallel.DistributedDataParallel(
            feature_extractor_model)
    elif config.gpu:
        disc_model = disc_model.to(device)
        gen_model = gen_model.to(device)
        feature_extractor_model = feature_extractor_model.to(device)
    else:
        # No usable device configuration: nothing to run.
        return

    # NOTE(review): train and test datasets are built from the same path and
    # shapes — presumably ImageDataset splits internally; confirm.
    train_dataset = ImageDataset(config.data.path,
                                 hr_shape=(config.data.hr_height,
                                           config.data.hr_width),
                                 lr_shape=(config.data.lr_height,
                                           config.data.lr_width))
    test_dataset = ImageDataset(config.data.path,
                                hr_shape=(config.data.hr_height,
                                          config.data.hr_width),
                                lr_shape=(config.data.lr_height,
                                          config.data.lr_width))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.data.batch_size,
        shuffle=config.data.shuffle, num_workers=config.data.workers,
        pin_memory=config.data.pin_memory, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=config.data.batch_size,
        shuffle=config.data.shuffle, num_workers=config.data.workers,
        pin_memory=config.data.pin_memory)

    if args.train:
        # Trainer settings: shared MSE criterion, one Adam per sub-model.
        trainer = GANTrainer(config.train, train_loader,
                             (disc_model, gen_model, feature_extractor_model))
        criterion = nn.MSELoss().to(device)
        disc_optimizer = torch.optim.Adam(disc_model.parameters(),
                                          config.train.hyperparameters.lr)
        gen_optimizer = torch.optim.Adam(gen_model.parameters(),
                                         config.train.hyperparameters.lr)
        fe_optimizer = torch.optim.Adam(feature_extractor_model.parameters(),
                                        config.train.hyperparameters.lr)
        trainer.setCriterion(criterion)
        trainer.setDiscOptimizer(disc_optimizer)
        trainer.setGenOptimizer(gen_optimizer)
        trainer.setFEOptimizer(fe_optimizer)

        # Evaluator settings (shares the criterion with the trainer).
        evaluator = GANEvaluator(
            config.evaluate, val_loader,
            (disc_model, gen_model, feature_extractor_model))
        evaluator.setCriterion(criterion)

    if args.test:
        pass

    # Turn on benchmark if the input sizes don't vary: lets cuDNN pick the
    # fastest algorithms for the fixed shapes.
    cudnn.benchmark = True

    start_epoch = 0
    best_precision = 0

    # Optionally resume from a checkpoint.
    if config.train.resume:
        [start_epoch, best_precision] = trainer.load_saved_checkpoint(
            checkpoint=None)

    # change value to test.hyperparameters on testing
    for epoch in range(start_epoch, config.train.hyperparameters.total_epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        if args.train:
            trainer.adjust_learning_rate(epoch)
            trainer.train(epoch)
            prec1 = evaluator.evaluate(epoch)

        if args.test:
            pass

        # Remember best prec@1 and save checkpoint.
        if args.train:
            is_best = prec1 > best_precision
            best_precision = max(prec1, best_precision)
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': disc_model.state_dict(),
                    'best_precision': best_precision,
                    # BUG FIX: `optimizer` was undefined here (NameError on
                    # every checkpoint save); persist the discriminator's
                    # optimizer to match the disc_model state_dict above.
                    'optimizer': disc_optimizer.state_dict(),
                },
                is_best, checkpoint=None)