Example 1
def train_autoencoder(device, args):
    # model definition
    model = FeatureExtractor()
    model.to(device)
    # data definition
    all_chunks = []
    # concatenate all chunk files; the class of each chunk is
    # irrelevant since we are building a generative dataset
    for label_dir in filesystem.listdir_complete(filesystem.train_audio_chunks_dir):
        all_chunks.extend(filesystem.listdir_complete(label_dir))
    train_chunks, eval_chunks = train_test_split(all_chunks, test_size=args.eval_size)
    # transforms and dataset
    trf = normalize

    train_dataset = GenerativeDataset(train_chunks, transforms=trf)
    eval_dataset = GenerativeDataset(eval_chunks, transforms=trf)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=4, collate_fn=None, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False,  # no need to shuffle for evaluation
                                 num_workers=4, collate_fn=None, pin_memory=True)

    # main loop
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    loss_criterion = SoftDTW(use_cuda=True, gamma=0.1)  # assumes training on a CUDA device
    train_count = 0
    eval_count = 0
    for epoch in range(args.n_epochs):
        print('Epoch:', epoch, '/', args.n_epochs)
        train_count = train_step(model, train_dataloader, optimizer, loss_criterion, args.verbose_epochs, device, train_count)
        eval_count = eval_step(model, eval_dataloader, loss_criterion, args.verbose_epochs, device, eval_count)
        torch.save(model.state_dict(), os.path.join(wandb.run.dir, 'model_checkpoint.pt'))
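The train_step and eval_step helpers are not shown in this example. A minimal sketch of what train_step might look like, assuming the model returns a reconstruction of its input and that SoftDTW yields one loss value per sample (both assumptions, not the original project's code):

def train_step(model, dataloader, optimizer, criterion, verbose_epochs, device, count):
    # hypothetical sketch, not the original implementation
    model.train()
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        reconstruction = model(batch)                    # autoencoder forward pass
        loss = criterion(reconstruction, batch).mean()   # SoftDTW returns per-sample losses
        loss.backward()
        optimizer.step()
        if count % verbose_epochs == 0:
            print('train step', count, 'loss:', loss.item())
        count += 1
    return count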
Example 2
class Trainer(object):
    def __init__(self, src_domain, tgt_domain):
        self.num_epoch = 10
        self.gamma = 1.0
        print('construct dataset and dataloader...')
        train_dataset = TrainDataset(src_domain, tgt_domain)
        self.NEG_NUM = train_dataset.NEG_NUM
        self.input_dim = train_dataset.sample_dim
        self.train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        print('Done!')

        self.feature_extractor = FeatureExtractor(self.input_dim)
        self.optimizer = optim.SGD(self.feature_extractor.parameters(),
                                   lr=0.1,
                                   momentum=0.9)

    def train(self):
        for i in range(self.num_epoch):
            self.train_one_epoch(i)

    def train_one_epoch(self, epoch_ind):
        loss_item = 0
        for batch_idx, (src_pos, tgt_pos, tgt_negs) in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            src_pos_feature = self.feature_extractor(src_pos)
            tgt_pos_feature = self.feature_extractor(tgt_pos)
            tgt_negs_features = self.feature_extractor(
                tgt_negs.reshape(-1, self.input_dim))
            feature_dim = src_pos_feature.size()[1]
            tgt_negs_features = tgt_negs_features.reshape(
                -1, self.NEG_NUM, feature_dim)

            pos_sim = cosine_similarity(src_pos_feature, tgt_pos_feature)
            src_repeated_feature = src_pos_feature.unsqueeze(1).repeat(
                1, self.NEG_NUM, 1)
            neg_sims = cosine_similarity(src_repeated_feature,
                                         tgt_negs_features,
                                         dim=2)
            all_sims = torch.cat((pos_sim.unsqueeze(1), neg_sims), dim=1)

            PDQ = softmax(all_sims * self.gamma, dim=1)
            # neg_prob_sum = torch.sum(PDQ[:, 1:], 1)
            # prediction = torch.cat((PDQ[:, 0].unsqueeze(1), neg_prob_sum.unsqueeze(1)), dim=1)
            # batchsize = src_pos_feature.size()[0]
            # target = torch.zeros(batchsize).long()  # the first column is the positive
            # loss = nll_loss(prediction, target)  # nll_loss expects log-probabilities
            loss = -PDQ[:, 0].log().mean()

            loss.backward()
            self.optimizer.step()

            loss_item += loss.item()
            print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch_ind, batch_idx, len(self.train_loader), loss.item()))
        print("===> Epoch[{}] done. Avg. Loss: {:.4f}".format(
            epoch_ind, loss_item / len(self.train_loader)))
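The handwritten objective above, loss = -PDQ[:, 0].log().mean(), is exactly softmax cross-entropy against class 0 (the positive pair); the commented-out nll_loss variant was likely dropped because F.nll_loss expects log-probabilities rather than the probabilities in PDQ. A self-contained check of the equivalence (all names here are illustrative):

import torch
import torch.nn.functional as F

# dummy similarity logits: one positive plus four negatives per row
logits = torch.randn(8, 5)
gamma = 1.0

pdq = F.softmax(logits * gamma, dim=1)
manual = -pdq[:, 0].log().mean()
builtin = F.cross_entropy(logits * gamma, torch.zeros(8, dtype=torch.long))
assert torch.allclose(manual, builtin)  # identical up to floating-point error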
Example 3
                             batch_size=1024,
                             shuffle=True,
                             drop_last=True)

# -------------------------------------- Training Stage ------------------------------------------- #

precision = 1e-8

feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.CrossEntropyLoss()

optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())

scheduler_F = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_F,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=8,
                                                   verbose=True,
                                                   eps=precision)
scheduler_C = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_C,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=8,
                                                   verbose=True,
                                                   eps=precision)
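ReduceLROnPlateau is driven by a monitored metric rather than the epoch counter, so each scheduler must be stepped with a validation value; the loop that does so is not part of this excerpt. A minimal runnable sketch of the pattern (the model and metric are stand-ins):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 2)                      # stand-in model
opt = optim.Adam(model.parameters())
sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1,
                                             patience=8, eps=1e-8)

for epoch in range(3):
    # ... train for one epoch, then compute a validation loss ...
    val_loss = 1.0 / (epoch + 1)             # stand-in for a real metric
    sched.step(val_loss)                     # the scheduler reacts to the metric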
Example 4
def main(args):
    np.random.seed(0)
    torch.manual_seed(0)

    with open('config.yaml', 'r') as file:
        config = mapper(**yaml.safe_load(file))

    disc_model = Discriminator(input_shape=(config.data.channels,
                                            config.data.hr_height,
                                            config.data.hr_width))
    gen_model = GeneratorResNet()
    feature_extractor_model = FeatureExtractor()
    plt.ion()

    if config.distributed:
        disc_model.to(device)
        disc_model = nn.parallel.DistributedDataParallel(disc_model)
        gen_model.to(device)
        gen_model = nn.parallel.DistributedDataParallel(gen_model)
        feature_extractor_model.to(device)
        feature_extractor_model = nn.parallel.DistributedDataParallel(
            feature_extractor_model)
    elif config.gpu:
        # disc_model = nn.DataParallel(disc_model).to(device)
        # gen_model = nn.DataParallel(gen_model).to(device)
        # feature_extractor_model = nn.DataParallel(feature_extractor_model).to(device)
        disc_model = disc_model.to(device)
        gen_model = gen_model.to(device)
        feature_extractor_model = feature_extractor_model.to(device)
    else:
        return

    train_dataset = ImageDataset(config.data.path,
                                 hr_shape=(config.data.hr_height,
                                           config.data.hr_width),
                                 lr_shape=(config.data.lr_height,
                                           config.data.lr_width))
    test_dataset = ImageDataset(config.data.path,
                                hr_shape=(config.data.hr_height,
                                          config.data.hr_width),
                                lr_shape=(config.data.lr_height,
                                          config.data.lr_width))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        # shuffle and sampler are mutually exclusive: let the
        # DistributedSampler handle shuffling in the distributed case
        shuffle=(train_sampler is None and config.data.shuffle),
        num_workers=config.data.workers,
        pin_memory=config.data.pin_memory,
        sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=config.data.batch_size,
                                             shuffle=config.data.shuffle,
                                             num_workers=config.data.workers,
                                             pin_memory=config.data.pin_memory)

    if args.train:
        # trainer settings
        trainer = GANTrainer(config.train, train_loader,
                             (disc_model, gen_model, feature_extractor_model))
        criterion = nn.MSELoss().to(device)
        disc_optimizer = torch.optim.Adam(disc_model.parameters(),
                                          config.train.hyperparameters.lr)
        gen_optimizer = torch.optim.Adam(gen_model.parameters(),
                                         config.train.hyperparameters.lr)
        fe_optimizer = torch.optim.Adam(feature_extractor_model.parameters(),
                                        config.train.hyperparameters.lr)

        trainer.setCriterion(criterion)
        trainer.setDiscOptimizer(disc_optimizer)
        trainer.setGenOptimizer(gen_optimizer)
        trainer.setFEOptimizer(fe_optimizer)

        # evaluator settings
        evaluator = GANEvaluator(
            config.evaluate, val_loader,
            (disc_model, gen_model, feature_extractor_model))
        # optimizer = torch.optim.Adam(disc_model.parameters(), lr=config.evaluate.hyperparameters.lr,
        # 	weight_decay=config.evaluate.hyperparameters.weight_decay)
        evaluator.setCriterion(criterion)

    if args.test:
        pass

    # Enable cudnn benchmarking when input sizes don't vary;
    # it lets cuDNN pick the fastest algorithms for this machine
    cudnn.benchmark = True
    start_epoch = 0
    best_precision = 0

    # optionally resume from a checkpoint (the trainer only exists in train mode)
    if args.train and config.train.resume:
        [start_epoch,
         best_precision] = trainer.load_saved_checkpoint(checkpoint=None)

    # change value to test.hyperparameters on testing
    for epoch in range(start_epoch, config.train.hyperparameters.total_epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        if args.train:
            trainer.adjust_learning_rate(epoch)
            trainer.train(epoch)
            prec1 = evaluator.evaluate(epoch)

        if args.test:
            pass

        # remember best prec@1 and save checkpoint
        if args.train:
            is_best = prec1 > best_precision
            best_precision = max(prec1, best_precision)
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': disc_model.state_dict(),
                    'best_precision': best_precision,
                    'optimizer': disc_optimizer.state_dict(),
                },
                is_best,
                checkpoint=None)
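GANTrainer.save_checkpoint is not shown in this excerpt. A plausible minimal implementation matching the call above, following the common "save latest, copy best" convention (the file names are assumptions):

import shutil
import torch

def save_checkpoint(state, is_best, checkpoint=None):
    # hypothetical helper; the default file names are illustrative
    path = checkpoint or 'checkpoint.pth.tar'
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, 'model_best.pth.tar')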