Example #1
    def train_one_epoch(self, epoch):
        """Train one epoch."""
        self.audio_encoder.train()
        self.tag_encoder.train()
        self.cf_encoder.train()

        # losses
        train_pairwise_loss = 0
        train_pairwise_loss_1 = 0
        train_pairwise_loss_2 = 0
        train_pairwise_loss_3 = 0

        for batch_idx, (data, tags, cf_embeddings,
                        sound_ids) in enumerate(self.train_loader):
            self.iteration_idx += 1

            # TODO: remove this
            # tags should already be in tag_idxs form, except for the +1 index shift that reserves idx 0 for "no tag"
            # We probably want to add some pre-processing in data_loader.py,
            # e.g. select random tags from the 100, or select a random spectrogram chunk
            """
            tag_idxs = [
                ([idx+1 for idx, val in enumerate(tag_v) if val]
                 + self.max_num_tags*[0])[:self.max_num_tags]
                for tag_v in tags
            ]

            """
            curr_labels = []
            for curr_tags in tags:
                # shift indices by one so that 0 can act as the "no tag" padding index
                non_neg = [i + 1 for i in curr_tags if i != -1]
                new_tags = np.zeros(self.max_num_tags)
                # alternative: sample tags at random instead of keeping the first ones
                # new_tags[:len(non_neg)] = np.random.choice(non_neg, min(self.max_num_tags, len(non_neg)), replace=False)
                new_tags[:min(len(non_neg), self.max_num_tags)] = non_neg[:self.max_num_tags]
            tags_input = torch.tensor(curr_labels,
                                      dtype=torch.long).to(self.device)
            #tags_input = tags.to(self.device)
            x = data.view(-1, 1, 48, 256).to(self.device)
            cf_input = cf_embeddings.to(self.device)

            # encode
            z_audio, z_d_audio = self.audio_encoder(x)
            z_tags, attn = self.tag_encoder(tags_input,
                                            z_d_audio,
                                            mask=tags_input.unsqueeze(1))
            z_cf = self.cf_encoder(cf_input)

            # contrastive loss
            pairwise_loss_1 = contrastive_loss(z_d_audio, z_tags,
                                               self.contrastive_temperature)
            pairwise_loss_2 = contrastive_loss(z_d_audio, z_cf,
                                               self.contrastive_temperature)
            pairwise_loss_3 = contrastive_loss(z_cf, z_tags,
                                               self.contrastive_temperature)
            pairwise_loss = pairwise_loss_1 + pairwise_loss_2 + pairwise_loss_3

            # Optimize models
            """
            self.audio_opt.zero_grad()
            self.tag_opt.zero_grad()
            pairwise_loss.backward()
            self.audio_opt.step()
            self.tag_opt.step()
            """
            self.opt.zero_grad()
            pairwise_loss.backward()
            """
            clip_norm_params = {
                    'max_norm': 1.,
                    'norm_type': 2
            }
            torch.nn.utils.clip_grad_norm_(self.audio_encoder.parameters(), **clip_norm_params)
            torch.nn.utils.clip_grad_norm_(self.tag_encoder.parameters(), **clip_norm_params)
            """
            self.opt.step()

            train_pairwise_loss += pairwise_loss.item()
            train_pairwise_loss_1 += pairwise_loss_1.item()
            train_pairwise_loss_2 += pairwise_loss_2.item()
            train_pairwise_loss_3 += pairwise_loss_3.item()

            # write to tensorboard
            # This is too much data to send to tensorboard every iteration, but it can be useful for debugging/developing
            if False:
                self.tb.add_scalar("iter/contrastive_pairwise_loss",
                                   pairwise_loss.item(), self.iteration_idx)

            # logs per batch
            if batch_idx % self.log_interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tPairwise loss: {:.4f}'
                    .format(epoch, batch_idx * len(data),
                            len(self.train_loader.dataset),
                            100. * batch_idx / len(self.train_loader),
                            pairwise_loss.item()))

        # epoch logs
        train_pairwise_loss = train_pairwise_loss / self.length_train_dataset * self.batch_size
        train_pairwise_loss_1 = train_pairwise_loss_1 / self.length_train_dataset * self.batch_size
        train_pairwise_loss_2 = train_pairwise_loss_2 / self.length_train_dataset * self.batch_size
        train_pairwise_loss_3 = train_pairwise_loss_3 / self.length_train_dataset * self.batch_size
        print('====> Epoch: {}  Pairwise loss: {:.8f}'.format(
            epoch, train_pairwise_loss))
        print('\n')

        # tensorboard
        self.tb.add_scalar("contrastive_pairwise_loss/train/sum",
                           train_pairwise_loss, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/train/1",
                           train_pairwise_loss_1, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/train/2",
                           train_pairwise_loss_2, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/train/3",
                           train_pairwise_loss_3, epoch)

        if epoch % self.save_model_every == 0:
            torch.save(
                self.audio_encoder.state_dict(),
                str(
                    Path(self.save_model_loc, self.experiment_name,
                         f'audio_encoder_epoch_{epoch}.pt')))
            torch.save(
                self.tag_encoder.state_dict(),
                str(
                    Path(self.save_model_loc, self.experiment_name,
                         f'tag_encoder_att_epoch_{epoch}.pt')))
            torch.save(
                self.cf_encoder.state_dict(),
                str(
                    Path(self.save_model_loc, self.experiment_name,
                         f'cf_encoder_att_epoch_{epoch}.pt')))
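Every trainer in these examples calls a contrastive_loss(z_a, z_b, temperature) helper that is defined elsewhere in the codebase. For reference, a minimal symmetric InfoNCE (NT-Xent) loss over two aligned batches of embeddings could look like the sketch below; the actual implementation may differ.

import torch
import torch.nn.functional as F

def contrastive_loss(z_a, z_b, temperature=0.1):
    # L2-normalize both sets of embeddings so the dot product is a cosine similarity
    z_a = F.normalize(z_a, dim=1)
    z_b = F.normalize(z_b, dim=1)
    logits = z_a @ z_b.t() / temperature  # (B, B) pairwise similarities
    # the matching pairs (i, i) are the positives, everything else is a negative
    targets = torch.arange(z_a.size(0), device=z_a.device)
    # symmetrize over the a->b and b->a retrieval directions
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.t(), targets))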
Example #2
    def forward(self, data):
        output = self.enc(data.view(data.size(0), 50 * 50))

        return output.to(device)  # shape: [B, 100]
Example #3
    def val_dual_AE(self, epoch):
        """Validate the dual autoencoder."""
        self.audio_encoder.eval()
        self.audio_decoder.eval()
        self.tag_encoder.eval()
        self.tag_decoder.eval()

        val_audio_recon_loss = 0
        val_tags_recon_loss = 0
        val_loss = 0
        val_pairwise_loss = 0

        with torch.no_grad():
            for i, (data, tags, sound_ids) in enumerate(self.val_loader):
                # replace negative values with 0 using clamp. Negative values can appear in the
                # validation set because the minmax scaler is learned on the training data only.
                x = data.view(-1, 1, 96, 96).clamp(0).to(self.device)
                tags = tags.float().clamp(0).to(self.device)

                # encode
                z_audio, z_d_audio = self.audio_encoder(x)
                z_tags, z_d_tags = self.tag_encoder(tags)

                # audio
                x_recon = self.audio_decoder(z_audio)
                audio_recon_loss = kullback_leibler(x_recon, x)

                # tags
                tags_recon = self.tag_decoder(z_tags)
                tags_recon_loss = self.tag_recon_loss_function(
                    tags_recon, tags)

                # pairwise correspondence loss
                pairwise_loss = contrastive_loss(z_d_audio, z_d_tags,
                                                 self.contrastive_temperature)

                loss = audio_recon_loss + tags_recon_loss + pairwise_loss

                val_audio_recon_loss += audio_recon_loss.item()
                val_tags_recon_loss += tags_recon_loss.item()
                val_loss += loss.item()
                val_pairwise_loss += pairwise_loss.item()

                # display some examples
                if i == 0:
                    n = min(data.size(0), 8)

                    # write files with original and reconstructed spectrograms
                    comparison = torch.cat([
                        x.flip(2)[:n],
                        x_recon.view(self.batch_size, 1, 96, 96).flip(2)[:n]
                    ])
                    save_image(
                        comparison.cpu(),
                        f'reconstructions/reconstruction_{self.experiment_name}_{epoch}.png',
                        nrow=n)

                    # print the corresponding reconstructed tags if id2tag is passed
                    if self.id2tag:
                        for idx in range(n):
                            # top-6 reconstructed tags, sorted by predicted activation
                            top_tags = sorted(zip(
                                tags_recon.cpu()[idx].tolist(),
                                [self.id2tag[str(k)] for k in range(tags.size(1))]),
                                reverse=True)[:6]
                            print('\n', sound_ids.cpu()[idx].tolist()[0], top_tags)
                        print('\n')

        val_loss = val_loss / self.length_val_dataset * self.batch_size
        val_audio_recon_loss = val_audio_recon_loss / self.length_val_dataset * self.batch_size
        val_tags_recon_loss = val_tags_recon_loss / self.length_val_dataset * self.batch_size
        val_pairwise_loss = val_pairwise_loss / self.length_val_dataset * self.batch_size

        print('====> Val average loss: {:.4f}'.format(val_loss))
        print('recon loss audio: {:.4f}'.format(val_audio_recon_loss))
        print('recon loss tags: {:.4f}'.format(val_tags_recon_loss))
        print('pairwise loss: {:.4f}'.format(val_pairwise_loss))
        print('\n\n')

        # tensorboard
        self.tb.add_scalar("audio_recon_loss/val", val_audio_recon_loss, epoch)
        self.tb.add_scalar("tag_recon_loss/val", val_tags_recon_loss, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/val", val_pairwise_loss,
                           epoch)
        self.tb.add_scalar("total_loss/val", val_loss, epoch)
Example #4
def test(args, epoch):
    model.eval()
    #UT_test_loss = torch.zeros(args.batch_size).to(device)
    #test_loss = torch.zeros(args.batch_size).to(device)
    bs = args.batch_size
    true_loss = 0
    UT_loss = 0
    reg_loss = 0
    true_test_loss = 0
    UT_test_loss = 0
    reg_test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = preprocess(data).to(device)
            #data = data.to(device)
            #recon_batch, mu, logvar = model(data)
            mu, logvar = model.encode(data.view(-1, 784))
            
            z1, w0, w1 = model.unscented(mu, logvar)
            """
            z2 = []
            #var = torch.exp(logvar)
            Sigma = model.batch_diag(mu, var)
            dist_z = MultivariateNormal(mu, Sigma)
            for j in range(2*mu.shape[1]):
                z2.append(dist_z.sample())
            
            for inx1, sample1 in enumerate(z1):
                recon_batch1 = model.decode(sample1)
                UT_test_loss += loss_function(recon_batch1, data, mu, var).item()
            UT_test_loss /= len(z1)
            UT_loss += UT_test_loss
            print('UT loss: ', UT_test_loss/args.batch_size)
            UT_test_loss = 0
            for inx2, sample2 in enumerate(z2):
                recon_batch2 = model.decode(sample2)
                reg_test_loss += loss_function(recon_batch2, data, mu, var).item()
            reg_test_loss /= len(z2)
            reg_loss += reg_test_loss
            print('regular sampling loss: ', reg_test_loss/args.batch_size)
            reg_test_loss = 0
            
            z3 = []
            for j in range(10000):
                z3.append(dist_z.sample())
            for inx3, sample3 in enumerate(z3):
                recon_batch3 = model.decode(sample3)
                true_test_loss += loss_function(recon_batch3, data, mu, var).item()
            true_test_loss /= len(z3)
            true_loss += true_test_loss
            print('true sampling loss: ', true_test_loss/args.batch_size)
            true_test_loss = 0
            """
            
            UT_test_loss = (1/bs)*torch.sum(model.UT_sample_loss(data.view(-1, 784), z1, mu, logvar, w0, w1)).item()
            UT_loss += UT_test_loss
            print('UT score: ', UT_test_loss)
            UT_test_loss = 0
            """
            UT_test_loss = (1/bs)*torch.sum(model.UT_sample_loss_mu(data.view(-1, 784), mu, logvar)).item()
            UT_loss += UT_test_loss
            print('UT score: ', UT_test_loss)
            UT_test_loss = 0
            """
            reg_test_loss = (1/bs)*torch.sum(model.sample_loss(data.view(-1, 784), mu, logvar, 1)).item()
            reg_loss += reg_test_loss
            print('regular sampling score: ', reg_test_loss)
            reg_test_loss = 0
            
            true_test_loss = (1/bs)*torch.sum(model.sample_loss(data.view(-1, 784), mu, logvar, 5000)).item()
            true_loss += true_test_loss
            print('true sampling score: ', true_test_loss)
            true_test_loss = 0
            
            # if i == 0:
            #     n = min(data.size(0), 8)
            #     comparison = torch.cat([data[:n],
            #                             recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
            #     save_image(comparison.cpu(),
            #                'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    UT_loss /= (len(test_loader.dataset)/bs)
    reg_loss /= (len(test_loader.dataset)/bs)
    true_loss /= (len(test_loader.dataset)/bs)
    print('====> Test set loss with regular sampling: {:.4f}'.format(reg_loss))
    print('====> Test set loss with UT: {:.4f}'.format(UT_loss))
    print('====> True test set loss: {:.4f}'.format(true_loss))
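The loss_function referenced in the commented-out branches above is not shown. The call signature resembles the canonical VAE objective from the PyTorch examples, which combines a reconstruction term with the closed-form KL divergence of the approximate posterior; this is a reference sketch, not necessarily the author's exact function.

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    # reconstruction term: per-pixel binary cross-entropy, summed over the batch
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD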
Example #5
    def forward(self, x):
        # forward pass: one hidden ReLU layer, then raw logits for CrossEntropyLoss
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the Adam optimizer and Cross-Entropy loss function
model = Net()   
optimizer = optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
  
for batch_idx, (data, target) in enumerate(train_loader):
    data = data.view(-1, 28 * 28)
    optimizer.zero_grad()

    # Complete a forward pass
    output = model(data)

    # Compute the loss, gradients and change the weights
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

# Set the model in eval mode
model.eval()

for i, data in enumerate(test_loader, 0):
    inputs, labels = data
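The snippet is cut off inside this final loop. A self-contained sketch of how such an evaluation loop is usually completed (a plausible reconstruction, not the original code):

correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.view(-1, 28 * 28)
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest logit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Test accuracy: {:.2f}%'.format(100. * correct / total))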
Example #6
def objective(trial):

    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizers.
    optimizer_name = trial.suggest_categorical("optimizer",
                                               ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    if "checkpoint_path" in trial.user_attrs:
        checkpoint = torch.load(trial.user_attrs["checkpoint_path"])
        epoch_begin = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        accuracy = checkpoint["accuracy"]
    else:
        epoch_begin = 0

    # Get the FashionMNIST dataset.
    train_loader, valid_loader = get_mnist()

    path = f"pytorch_checkpoint/{trial.number}"
    os.makedirs(path, exist_ok=True)

    # Training of the model.
    for epoch in range(epoch_begin, EPOCHS):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Limiting training data for faster epochs.
            if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
                break

            data, target = data.view(data.size(0),
                                     -1).to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

        # Validation of the model.
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(valid_loader):
                # Limiting validation data.
                if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                    break
                data, target = data.view(data.size(0),
                                         -1).to(DEVICE), target.to(DEVICE)
                output = model(data)
                # Get the index of the max log-probability.
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

        trial.report(accuracy, epoch)

        # Save optimization status. We should save the objective value because the process may be
        # killed between saving the last model and recording the objective value to the storage.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "accuracy": accuracy,
            },
            os.path.join(path, "model.pt"),
        )

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return accuracy
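define_model(trial) and get_mnist() are defined elsewhere in this Optuna script. For context, a typical define_model in Optuna's PyTorch examples samples the architecture from the trial; the sketch below follows that pattern, and its exact hyperparameter ranges are assumptions.

import torch.nn as nn

def define_model(trial):
    # sample the number of layers, their widths, and dropout rates from the trial
    n_layers = trial.suggest_int("n_layers", 1, 3)
    layers = []
    in_features = 28 * 28
    for i in range(n_layers):
        out_features = trial.suggest_int(f"n_units_l{i}", 4, 128)
        layers.append(nn.Linear(in_features, out_features))
        layers.append(nn.ReLU())
        p = trial.suggest_float(f"dropout_l{i}", 0.2, 0.5)
        layers.append(nn.Dropout(p))
        in_features = out_features
    layers.append(nn.Linear(in_features, 10))
    # F.nll_loss in the training loop expects log-probabilities
    layers.append(nn.LogSoftmax(dim=1))
    return nn.Sequential(*layers)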
Example #7

def compute_hidden_state(rnn, data, z_where_prev, z_what_prev, z_pres_prev, h_prev, c_prev):
    n = data.size(0)
    data = data.view(n, -1, 1).squeeze(2)
    rnn_input = torch.cat((data, z_where_prev, z_what_prev, z_pres_prev), 1)
    h, c = rnn(rnn_input, (h_prev, c_prev))
    return h, c
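A hypothetical call, assuming AIR-style latents (a z_where of size 3, z_what of size 50, z_pres of size 1) and a 256-unit LSTM cell; the actual dimensions in the source model are not shown here.

import torch
import torch.nn as nn

rnn = nn.LSTMCell(784 + 3 + 50 + 1, 256)  # flattened 28x28 image + previous latents
data = torch.rand(8, 1, 28, 28)
z_where = torch.zeros(8, 3)
z_what = torch.zeros(8, 50)
z_pres = torch.ones(8, 1)
h = torch.zeros(8, 256)
c = torch.zeros(8, 256)
h, c = compute_hidden_state(rnn, data, z_where, z_what, z_pres, h, c)
print(h.shape)  # torch.Size([8, 256])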
Example #8
def train(epoch, model, train_loader, optimizer, cuda, log_interval, save_path,
          args):
    model.train()
    loss_dict = model.latest_losses()
    losses = {k + '_train': 0 for k in loss_dict}
    epoch_losses = {k + '_train': 0 for k in loss_dict}
    start_time = time.time()
    batch_idx, data = None, None
    for batch_idx, (data, _, _) in enumerate(train_loader):
        if data.shape[0] != args["batch_size"] or data.shape[1] != 9:
            continue
        data = data.view(args["batch_size"] * 9, 3, 224, 224)[:10, :, :, :]
        data = normalize(data)
        # print("before shape", data.shape)
        data = torch.nn.functional.pad(input=data,
                                       pad=(0, 0, 16, 16, 0, 0, 0, 0),
                                       mode='constant',
                                       value=0)
        if cuda:
            data = data.cuda()
        optimizer.zero_grad()
        # note: wrapping the model in DataParallel once, outside the loop, would avoid re-wrapping every batch
        outputs = nn.DataParallel(model)(data)
        loss = model.loss_function(data, *outputs)
        loss.backward()
        optimizer.step()
        latest_losses = model.latest_losses()
        for key in latest_losses:
            losses[key + '_train'] += float(latest_losses[key])
            epoch_losses[key + '_train'] += float(latest_losses[key])

        if batch_idx % log_interval == 0:
            for key in latest_losses:
                losses[key + '_train'] /= log_interval
            loss_string = ' '.join(
                ['{}: {:.6f}'.format(k, v) for k, v in losses.items()])
            logging.info(
                'Train Epoch: {epoch} [{batch:5d}/{total_batch} ({percent:2d}%)]   time:'
                ' {time:3.2f}   {loss}'.format(
                    epoch=epoch,
                    batch=batch_idx * len(data),
                    total_batch=len(train_loader) * len(data),
                    percent=int(100. * batch_idx / len(train_loader)),
                    time=time.time() - start_time,
                    loss=loss_string))
            start_time = time.time()
            for key in latest_losses:
                losses[key + '_train'] = 0
        if batch_idx == (len(train_loader) - 1):
            save_reconstructed_images(data, epoch, outputs[0], save_path,
                                      'reconstruction_train')
        if args["dataset"] == 'imagenet' and batch_idx * len(data) > 25000:
            break

    for key in epoch_losses:
        epoch_losses[key] /= (len(train_loader.dataset) /
                              train_loader.batch_size)
    loss_string = '\t'.join(
        ['{}: {:.6f}'.format(k, v) for k, v in epoch_losses.items()])
    logging.info('====> Epoch: {} {}'.format(epoch, loss_string))
    model.print_atom_hist(outputs[3])
    return epoch_losses
Example #9
    def forward(self, data):
        data = data.view(-1, self.D_in)
        h1 = self.relu(self.fc1(data))
        h2 = self.relu(self.fc2(h1))

        return self.fc3(h2)
Example #10

def train_CEDA_gmm_out(model,
                       train_loader,
                       ood_loader,
                       optimizer,
                       epoch,
                       lam=1.,
                       verbose=10):
    criterion = nn.NLLLoss()

    model.train()

    train_loss = 0
    likelihood_loss = 0
    correct = 0
    margin = np.log(4.)

    if ood_loader is not None:
        ood_loader.dataset.offset = np.random.randint(len(ood_loader.dataset))
        ood_loader_iter = iter(ood_loader)

    p_in = torch.tensor(1. / (1. + lam), dtype=torch.float).cuda()
    p_out = torch.tensor(lam, dtype=torch.float).cuda() * p_in

    log_p_in = p_in.log()
    log_p_out = p_out.log()

    start = time.time()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()

        noise = next(ood_loader_iter)[0].cuda()

        optimizer.zero_grad()

        full_data = torch.cat([data, noise], 0)
        full_out = model(full_data)
        full_out = F.log_softmax(full_out, dim=1)

        output = full_out[:data.shape[0]]
        output_adv = full_out[data.shape[0]:]

        like_in_in = torch.logsumexp(model.mm(data.view(data.shape[0], -1)), 0)
        like_out_in = torch.logsumexp(model.mm(noise.view(noise.shape[0], -1)),
                                      0)

        like_in_out = torch.logsumexp(
            model.mm_out(data.view(data.shape[0], -1)), 0)
        like_out_out = torch.logsumexp(
            model.mm_out(noise.view(noise.shape[0], -1)), 0)

        loss1 = criterion(output, target)
        loss2 = -output_adv.mean()
        loss3 = -torch.logsumexp(
            torch.stack([log_p_in + like_in_in, log_p_out + like_in_out], 0),
            0).mean()
        loss4 = -torch.logsumexp(
            torch.stack([log_p_in + like_out_in, log_p_out + like_out_out], 0),
            0).mean()

        loss = p_in * (loss1 + loss3) + p_out * (loss2 + loss4)

        loss.backward()
        optimizer.step()

        likelihood_loss += loss3.item()
        train_loss += loss.item()
        _, predicted = output.max(1)
        correct += predicted.eq(target).sum().item()

        threshold = model.mm.logvar.max() + margin
        idx = model.mm_out.logvar < threshold
        model.mm_out.logvar.data[idx] = threshold

        if (batch_idx % verbose == 0) and verbose > 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    print('Time: ', time.time() - start)
Example #11

train_loading = torch.utils.data.DataLoader(
    datasets.MNIST('data',
                   train=True,
                   download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=BatchSize)
test_loading = torch.utils.data.DataLoader(
    datasets.MNIST('data',
                   train=False,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=BatchSize)

#%% Arrange training data properly
input_data = []
target_data = []

for _, (data, target) in enumerate(train_loading):
    data = data.view(-1, 784)
    data = data.numpy()
    input_data.append(data)

full_input = np.array(input_data[0])
full_input = torch.Tensor(full_input)

for i in range(1, len(input_data)):  # remaining batches (938 for MNIST with batch size 64)
    intermed = np.array(input_data[i])
    intermed = torch.Tensor(intermed)
    full_input = torch.cat((full_input, intermed), 0)

#%% Arrange testing data properly
input_data = []
for _, (data, target) in enumerate(test_loading):
    data = data.view(-1, 784)
Example #12
def train_and_get_data_RBME(dataset, n_vis, n_hid, k, n_epochs, batch_size, lr,
                            momentum, filename):
    # create a Restricted Boltzmann Machine

    # list0 is a module-level list that the model fills during sampling; it is read back below
    global list0
    list0 = []
    listloss = []

    model = RBM(n_vis, n_hid, k)
    dataset1 = CustomDataset(dataset)
    train_loader = DataLoader(dataset1, batch_size, shuffle=True)
    # optimizer with lr=0 and momentum=0: no parameter updates yet
    train_op = optim.SGD(model.parameters(), lr=0, momentum=0)
    # run the RBM once to record the loss of the untrained model
    model.train()

    loss0 = []
    for _, (data) in enumerate(train_loader):
        v, v_gibbs = model(data.view(-1, 64))
        loss = model.free_energy(v) - model.free_energy(v_gibbs)

        loss0.append(loss.item())

        train_op.zero_grad()
        loss.backward()
        train_op.step()
    listloss.append(np.mean(loss0))
    # optimizer
    train_op = optim.SGD(model.parameters(), lr, momentum)
    # train_op = optim.Adam(model.parameters(), lr)

    # train the RBM model
    model.train()
    sign_changed = 0

    for epoch in range(n_epochs):
        loss_ = []
        for _, (data) in enumerate(train_loader):
            v, v_gibbs = model(data.view(-1, 64))
            loss = model.free_energy(v) - model.free_energy(v_gibbs)
            loss_.append(loss.item())
            train_op.zero_grad()
            loss.backward()
            train_op.step()
        listloss.append(np.mean(loss_))

        print('Epoch %3.d | Loss=%6.4f' % (
            epoch + 1,
            np.mean(loss_),
        ))

        if np.abs(listloss[-1]) < 0.1:
            print("Early stopping: |loss| < 0.1")
            break
    end_epoch = epoch

    states_in_epoch = []
    for e in range(end_epoch + 1):
        for i in range(int(e * len(list0) / (end_epoch + 1)),
                       int((e + 1) * len(list0) / (end_epoch + 1))):
            for j in range(batch_size):
                states_in_epoch.append(str(list0[i][j].tolist()))
    a, b = get_listmk(states_in_epoch)
    Hs = get_H_s(a, b)
    Hk = get_H_k(a, b)
    with open(
            'models/%s_model_n_hid=%s_%s.pkl' %
        (str(datetime.today())[:10], n_hid, filename), 'wb') as f:
        pkl.dump(model, f)
    with open(
            'data/%s_data_n_hid=%s_%s.pkl' %
        (str(datetime.today())[:10], n_hid, filename), 'wb') as f:
        pkl.dump([a, b, Hs, Hk], f)
Example #13
def apply_batch_norm(batch_norm, data):
    # flatten all leading dimensions, normalize over the feature dimension, then restore the shape
    original_shape = data.shape
    data_list = data.view(-1, batch_norm.num_features)
    data_list = batch_norm(data_list)
    return data_list.view(original_shape)
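For example, this lets a 1-D batch-norm layer normalize the feature dimension of a higher-rank tensor (the names and shapes below are illustrative):

import torch

bn = torch.nn.BatchNorm1d(num_features=64)
x = torch.randn(32, 10, 64)   # (batch, time, features)
y = apply_batch_norm(bn, x)   # flattened to (320, 64) internally
print(y.shape)                # torch.Size([32, 10, 64])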
Example #14
    def val(self, epoch):
        """Validation."""
        # There is a bit of code repetition with train_one_epoch here...
        self.audio_encoder.eval()
        self.tag_encoder.eval()
        self.cf_encoder.eval()

        val_pairwise_loss = 0
        val_pairwise_loss_1 = 0
        val_pairwise_loss_2 = 0
        val_pairwise_loss_3 = 0

        with torch.no_grad():
            for i, (data, tags, cf_embeddings,
                    sound_ids) in enumerate(self.val_loader):

                curr_labels = []
                for curr_tags in tags:
                    # shift indices by one so that 0 can act as the "no tag" padding index
                    non_neg = [i + 1 for i in curr_tags if i != -1]
                    new_tags = np.zeros(self.max_num_tags)
                    # alternative: sample tags at random instead of keeping the first ones
                    # new_tags[:len(non_neg)] = np.random.choice(non_neg, min(self.max_num_tags, len(non_neg)), replace=False)
                    new_tags[:min(len(non_neg), self.max_num_tags)] = non_neg[:self.max_num_tags]
                    curr_labels.append(new_tags)
                tags_input = torch.tensor(curr_labels,
                                          dtype=torch.long).to(self.device)

                x = data.view(-1, 1, 48, 256).to(self.device)
                cf_input = cf_embeddings.to(self.device)

                # encode
                z_audio, z_d_audio = self.audio_encoder(x)
                z_tags, attn = self.tag_encoder(tags_input,
                                                z_d_audio,
                                                mask=tags_input.unsqueeze(1))
                z_cf = self.cf_encoder(cf_input)

                # contrastive loss between each pair of modalities
                pairwise_loss_1 = contrastive_loss(
                    z_d_audio, z_tags, self.contrastive_temperature)
                pairwise_loss_2 = contrastive_loss(
                    z_d_audio, z_cf, self.contrastive_temperature)
                pairwise_loss_3 = contrastive_loss(
                    z_cf, z_tags, self.contrastive_temperature)
                pairwise_loss = pairwise_loss_1 + pairwise_loss_2 + pairwise_loss_3

                val_pairwise_loss += pairwise_loss.item()
                val_pairwise_loss_1 += pairwise_loss_1.item()
                val_pairwise_loss_2 += pairwise_loss_2.item()
                val_pairwise_loss_3 += pairwise_loss_3.item()

        val_pairwise_loss = val_pairwise_loss / self.length_val_dataset * self.batch_size
        val_pairwise_loss_1 = val_pairwise_loss_1 / self.length_val_dataset * self.batch_size
        val_pairwise_loss_2 = val_pairwise_loss_2 / self.length_val_dataset * self.batch_size
        val_pairwise_loss_3 = val_pairwise_loss_3 / self.length_val_dataset * self.batch_size

        print('====> Val average pairwise loss: {:.4f}'.format(
            val_pairwise_loss))
        print('\n\n')

        # tensorboard
        self.tb.add_scalar("contrastive_pairwise_loss/val/sum",
                           val_pairwise_loss, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/val/1",
                           val_pairwise_loss_1, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/val/2",
                           val_pairwise_loss_2, epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/val/3",
                           val_pairwise_loss_3, epoch)
        if not (math.isinf(val_pairwise_loss)
                or math.isnan(val_pairwise_loss)):
            if val_pairwise_loss < self.curr_min_val:
                self.curr_min_val = val_pairwise_loss
                torch.save(
                    self.audio_encoder.state_dict(),
                    str(
                        Path(self.save_model_loc, self.experiment_name,
                             'audio_encoder_epoch_best.pt')))
                torch.save(
                    self.tag_encoder.state_dict(),
                    str(
                        Path(self.save_model_loc, self.experiment_name,
                             'tag_encoder_att_epoch_best.pt')))
                torch.save(
                    self.cf_encoder.state_dict(),
                    str(
                        Path(self.save_model_loc, self.experiment_name,
                             'cf_encoder_att_epoch_best.pt')))
Example #15
def train(model, train_loader, n_epochs, lr, momentum):
    """
    Train an RBM model.
    Args:
        model: The RBM model.
        train_loader (DataLoader): The data loader.
        n_epochs (int): The number of epochs.
        lr (float): The learning rate.
        momentum (float): The SGD momentum.
    Returns:
        The trained model, the last epoch index, and the last batch of generated samples.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # optimizer with lr=0 and momentum=0: no parameter updates yet
    train_op = optim.SGD(model.parameters(), lr=0, momentum=0)
    model.train()

    # run the RBM once to record the loss of the untrained model
    loss0 = []
    for data in train_loader:
        v, v_gibbs = model(data.view(-1, 64))
        loss = model.free_energy(v) - model.free_energy(v_gibbs)

        loss0.append(loss.item())

        train_op.zero_grad()
        loss.backward()
        train_op.step()
    # listloss is a module-level list defined elsewhere in the script
    listloss.append(np.mean(loss0))
    
    # optimizer
    train_op = optim.SGD(model.parameters(), lr, momentum)
    # train_op = optim.Adam(model.parameters(), lr)

    # train the RBM model
    model.train()
    for epoch in range(n_epochs):
        loss_ = []
        v_generated = []
        E_generated = []
        for data in train_loader:
            v, v_gibbs = model(data.view(-1, 64))
            loss = model.free_energy(v) - model.free_energy(v_gibbs)
            loss_.append(loss.item())
            train_op.zero_grad()
            loss.backward()
            train_op.step()
            v_generated.append(v_gibbs)
        listloss.append(np.mean(loss_))
        for i in range(len(v_generated)):
            for j in range(len(v_generated[0])):
                E_generated.append(energy(v_generated[i][j]))
        print('Epoch %3.d | Loss=%6.4f | E_gen_mean=%6.6f | E_gen_std=%6.6f' %
              (epoch + 1, np.mean(loss_), np.mean(E_generated), np.std(E_generated)))
        
        if np.abs(listloss[-1]) < 0.1:
            print("Early stopping: |loss| < 0.1")
            break
    end_epoch = epoch

    return model, end_epoch, v_generated[0]
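Several of these RBM examples rely on model.free_energy(v). A common PyTorch implementation for a Bernoulli RBM is shown here for reference; the attribute names W, v_bias, and h_bias are assumptions about the RBM class used above.

import torch.nn.functional as F

def free_energy(self, v):
    # F(v) = -v . v_bias - sum_j softplus(W_j . v + h_bias_j)
    vbias_term = v.mv(self.v_bias)                    # (B,)
    wx_b = F.linear(v, self.W, self.h_bias)           # (B, n_hid)
    hidden_term = wx_b.exp().add(1).log().sum(dim=1)  # softplus, summed over hidden units
    return (-hidden_term - vbias_term).mean()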
Example #16
from __future__ import print_function
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
# %matplotlib inline
use_cuda = False
batch_size = 32
latent_size = 20  # z dim
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(datasets.MNIST(
    '../data', train=True, download=True, transform=transforms.ToTensor()),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           **kwargs)
test_loader = torch.utils.data.DataLoader(datasets.MNIST(
    '../data', train=False, transform=transforms.ToTensor()),
                                          batch_size=batch_size,
                                          shuffle=True,
                                          **kwargs)

def to_var(x):
    x = Variable(x)
    if use_cuda:
        x = x.cuda()
    return x
Example #17
 def forward(self, data):
     return self.model(data.view(-1, 28 * 28))
Example #18
    model.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    valdir = args.data + args.distortion_name
    val_loader = torch.utils.data.DataLoader(
        VideoFolder(root=valdir, transform=transforms.Compose([transforms.ToTensor(), normalize])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    predictions, ranks = [], []
    with torch.no_grad():

        for data, target in val_loader:
            num_vids = data.size(0)
            data = data.view(-1, 3, 224, 224).cuda()

            output = model(data)

            for vid in output.view(num_vids, -1, 1000):
                predictions.append(vid.argmax(1).to('cpu').numpy())
                ranks.append([np.uint16(rankdata(-frame, method='ordinal')) for frame in vid.to('cpu').numpy()])

    ranks = np.asarray(ranks)

    fr = flip_prob(predictions, args.distortion_name)
    t5d = ranking_dist(ranks, args.distortion_name, mode='top5')

    print('Computing Metrics\n')
    print('Flipping Prob\t{:.5f}'.format(fr))
    print('Top5 Distance\t{:.5f}'.format(t5d))
Example #19
        # one_hot_vector = self.vector
        # self.vector = [0] * self.max
        # return torch.LongTensor(one_hot_vector)
        return torch.tensor(label)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        keypoint = np.asarray(self.datas[idx]).astype(np.float32).reshape(
            frm, int(kps / 2), 2)  # [30, 17, 2]
        # keypoint = keypoint[:, np.newaxis]  # [30, 1, 17, 2]
        keypoint = torch.from_numpy(keypoint).unsqueeze(1)
        one_hot_label = self.get_one_hot_num(self.labels[idx])

        return (keypoint, one_hot_label)


if __name__ == '__main__':
    datas = [[0] * 30 * 34 for _ in range(6)]
    #print(len(datas[0]))
    labels = [0, 0, 1, 1, 0, 1]
    input_channels = (1, 17, 2)
    dataset = ConvLstmLoader(datas, labels, 2)
    dataloader = DataLoader(dataset)
    dataiter = iter(dataloader)
    data, label = next(dataiter)  # `dataiter.next()` is Python 2 style
    data = data.view(-1, 30, input_channels[0], input_channels[1],
                     input_channels[2])
    print(data.shape)  # torch.Size([1, 30, 1, 17, 2])
Example #20

                      train=True,
                      transform=transforms.Compose([transforms.ToTensor()]),
                      download=True)
 n_channel = 1
 dataloader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batchSize,
                                          shuffle=True)
 print("=====> Building VAE")
 vae = VAE().to(device)
 vae.load_state_dict(torch.load('./VAE-WGANGP-VAE_v2.pth'))
 pos = []
 label = []
 for epoch in range(nepoch):
     for i, (data, lab) in enumerate(dataloader, 0):
         num_img = data.size(0)
         data = data.view(num_img, 1, 28, 28).to(device)  # reshape images to (N, 1, 28, 28)
         x, mean, logstd = vae(data)  # pass the real images through the VAE
         pos.append(mean)
         label.append(lab)
         if (i == 100):
             break
 pos = torch.cat(pos)
 label = torch.cat(label)
 print(pos.shape)
 print(label.shape)
 for i in range(10):
     plt.scatter(pos[label == i][:, 0].detach().cpu().numpy(),
                 pos[label == i][:, 1].detach().cpu().numpy(),
                 alpha=0.5,
                 label=i)
 plt.title('VAE-WGANGP-MNIST')
Example #21
def train(model,
          optimizer,
          train_loader,
          loss_func,
          epochs=1,
          show_prog=100,
          summary=None,
          test_loader=None,
          scheduler=None,
          beta=1):

    if summary:
        writer = SummaryWriter()

    ohc = OneHotEncoder(sparse=False)

    # fit to some dummy data to prevent errors later
    ohc.fit(np.arange(0, 10).reshape(10, 1))

    b_size = float(train_loader.batch_size)

    #writer.add_graph_onnx(model)

    # add initial values for the likelihoods to prevent weird glitches in tensorboard
    if summary:
        if test_loader:
            test_loss = get_loss(model, test_loader, loss_func, ohc)
            writer.add_scalar('loss/ave_test_loss_per_datapoint', -test_loss,
                              0)
        train_loss = get_loss(model, train_loader, loss_func, ohc)
        writer.add_scalar('loss/ave_loss_per_datapoint', -train_loss, 0)

    model.train()
    for i in tqdm(range(epochs)):
        if scheduler:
            scheduler.step()
            print(optimizer.state_dict()['param_groups'][0])
        for batch_idx, (data) in enumerate(train_loader):

            if type(data) == list:
                label = data[1]
                data = data[0]

            #n_iter = (i*len(train_loader))+batch_idx
            n_iter = (i * len(train_loader) * b_size) + batch_idx * b_size

            data = Variable(data.view(train_loader.batch_size, -1),
                            requires_grad=False)
            if model.conditional:
                label = Variable(
                    torch.Tensor(
                        ohc.transform(label.numpy().reshape(len(label), 1))))
                data = torch.cat([data, label], dim=1)
            data = data.cuda()  # make it GPU friendly
            # reset the optimizer so we don't keep gradients from the previous batch
            optimizer.zero_grad()
            dec_m, dec_v, enc_m, enc_v = model(data)  # forward pass
            if model.conditional:
                data_o = data[:, :-10]
            else:
                data_o = data
            loss = loss_func(enc_m, enc_v, data_o, dec_m, dec_v, model,
                             beta)  # get the loss
            if summary:
                # write the negative log-likelihood ELBO per data point to tensorboard
                writer.add_scalar('loss/ave_loss_per_datapoint',
                                  -loss.item() / b_size, n_iter)
                #w_s = torch.cat([torch.cat(layer.weight.data) for layer in model.children()]).abs().sum()
                #writer.add_scalar('sum of NN weights', w_s, n_iter) # check for regularisation
            loss.backward()  # backprop the loss
            # step the optimizer based on the gradients (i.e. update the params)
            optimizer.step()
            if batch_idx % show_prog == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    i, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                if summary:
                    writer.add_image('real_image',
                                     data_o[1].view(-1, model.h,
                                                    model.w), n_iter)
                    a, _, _, _ = model(data[1:2].cuda())
                    writer.add_image('reconstruction',
                                     a.view(-1, model.h, model.w), n_iter)
                    if model.conditional:
                        p = np.random.randint(0, 10)
                        num = [0] * 10
                        num[p] = 1
                        num = Variable(torch.Tensor(num)).cuda()
                        b, _ = model.decode(torch.cat([model.sample(), num]))
                    else:
                        b, _ = model.decode(model.sample())
                    writer.add_image('from_noise',
                                     b.view(-1, model.h, model.w), n_iter)
        if test_loader and summary:
            test_loss = get_loss(model, test_loader, loss_func, ohc)
            writer.add_scalar('loss/ave_test_loss_per_datapoint', -test_loss,
                              n_iter + b_size)
Example #22
n_hid = 8192  # number of neurons in the hidden layer
n_vis = 27  # input size
k = 1  # The number of Gibbs sampling

# load data
original_data = np.load('../data/test_set.npy')
clean_data = data_processing.missing_values(original_data, method='zeros')
training_data = torch.from_numpy(clean_data)
train_loader = torch.utils.data.DataLoader(training_data,
                                           batch_size=batch_size,
                                           shuffle=False)

# create a Restricted Boltzmann Machine
model = RBM(n_vis=n_vis, n_hid=n_hid, k=k)
train_op = optim.Adam(model.parameters(), lr)

# train the model
model.train()
for epoch in range(n_epochs):
    loss_ = []
    for i, data_target in enumerate(train_loader):
        # each row holds 27 input features followed by 27 target features
        data, target = torch.split(data_target, 27, dim=1)
        data = data.float()
        target = target.float()
        v, v_gibbs = model(data.view(-1, 27))
        loss = model.free_energy(v) - model.free_energy(v_gibbs)
        loss_.append(loss.item())
        train_op.zero_grad()
        loss.backward()
        train_op.step()
    print('Epoch %d\t Loss=%.4f' % (epoch, np.mean(loss_)))
Example #23
 def reconstruction_loss(self, data, reconstructions):
     loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1),
                          data.view(reconstructions.size(0), -1))
     # scale down so the reconstruction term does not dominate the total loss
     return loss * 0.0005
Example #24
             h1_sample] = self.gibbs_hvh(h0_sample)
            [pre_sig_h1, h1_mean, h0_sample, pre_sig_v1, v1_mean,
             v1_sample] = self.gibbs_vhv(v1_sample)

        nv_sample = v1_sample[-1]
        cost = torch.mean(self.free_energy(self.input_data)) - torch.mean(
            self.free_energy(nv_sample))


rbm = RBM(n_vis=784, n_hid=500)
train_op = optim.SGD(rbm.parameters(), 0.1)

for epoch in range(4):
    loss_ = []
    for _, (data, target) in enumerate(train_loader):
        sample_data = Variable(data.view(-1, 784)).bernoulli()
        v, v1 = rbm(sample_data)
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        loss_.append(loss.item())
        train_op.zero_grad()
        loss.backward()
        train_op.step()

    print(np.mean(loss_))


def monitoring(file_name, img):
    imgplot = np.transpose(img.numpy(), (1, 2, 0))
    f = "./%s.png" % file_name
    plt.imshow(imgplot)
    plt.imsave(f, imgplot)
Example #25
 def forward(self, data):
     data = data.view(-1, 28 * 28)
     # assumes len(self.dropouts) == len(self.layers) - 1, so zip skips the output layer
     for layer, dropout in zip(self.layers, self.dropouts):
         data = F.relu(layer(data))
         data = dropout(data)
     return F.log_softmax(self.layers[-1](data), dim=1)
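A container consistent with this forward pass, where each hidden layer is paired with a dropout and the output layer stands alone, might look like the hypothetical sketch below (layer sizes are assumptions):

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, sizes=(784, 256, 128, 10), p=0.5):
        super().__init__()
        # one Linear per consecutive size pair; the last one is the output layer
        self.layers = nn.ModuleList(
            nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:]))
        # one dropout per hidden layer only
        self.dropouts = nn.ModuleList(
            nn.Dropout(p) for _ in range(len(sizes) - 2))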
Example #26
    def train_one_epoch(self, epoch):
        """Train one epoch."""
        self.audio_encoder.train()
        self.mo_encoder.train()

        # losses
        train_pairwise_loss = 0
        train_pairwise_loss_1 = 0
        train_pairwise_loss_2 = 0
        train_pairwise_loss_3 = 0

        for batch_idx, (data, tags, cf_embeddings,
                        sound_ids) in enumerate(self.train_loader):
            self.iteration_idx += 1

            # TODO: remove this
            # tags should already be in tag_idxs form, except for the +1 index shift that reserves idx 0 for "no tag"
            # We probably want to add some pre-processing in data_loader.py,
            # e.g. select random tags from the 100, or select a random spectrogram chunk
            """
            tag_idxs = [
                ([idx+1 for idx, val in enumerate(tag_v) if val]
                 + self.max_num_tags*[0])[:self.max_num_tags]
                for tag_v in tags
            ]
            """

            x = data.view(-1, 1, 48, 256).to(self.device)

            if self.encoder_type == "gnr" or self.encoder_type == "MF_gnr":
                curr_labels = []
                for curr_tags in tags:
                    non_neg = [i for i in curr_tags if i != -1]
                    new_tags = np.zeros(self.size_voc)
                    new_tags[non_neg] = 1
                    curr_labels.append(new_tags)
                tags_input = torch.tensor(curr_labels,
                                          dtype=torch.float).to(self.device)

            if self.encoder_type == "MF" or self.encoder_type == "MF_gnr":
                cf_input = cf_embeddings.to(self.device)

            # encode
            z_audio, z_d_audio = self.audio_encoder(x)
            z_cf, z_tags = self.mo_encoder(z_d_audio)

            # supervision losses (the criterion objects could be built once, outside the batch loop)
            if self.encoder_type == "gnr" or self.encoder_type == "MF_gnr":
                bceloss = torch.nn.BCEWithLogitsLoss()
                pairwise_loss_1 = bceloss(z_tags, tags_input)
                pairwise_loss = pairwise_loss_1
            if self.encoder_type == "MF" or self.encoder_type == "MF_gnr":
                mseloss = torch.nn.MSELoss()
                pairwise_loss_2 = mseloss(z_cf, cf_input)
                pairwise_loss = pairwise_loss_2
            if self.encoder_type == "MF_gnr":
                pairwise_loss = pairwise_loss_1 + pairwise_loss_2

            # Optimize models
            self.opt.zero_grad()
            pairwise_loss.backward()
            self.opt.step()

            train_pairwise_loss += pairwise_loss.item()
            if self.encoder_type == "gnr":
                train_pairwise_loss_1 += pairwise_loss_1.item()
            if self.encoder_type == "MF":
                train_pairwise_loss_2 += pairwise_loss_2.item()

            # write to tensorboard
            # This is too much data to send to tensorboard every iteration, but it can be useful for debugging/developing
            if False:
                self.tb.add_scalar("iter/contrastive_pairwise_loss",
                                   pairwise_loss.item(), self.iteration_idx)

            # logs per batch
            if batch_idx % self.log_interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tPairwise loss: {:.4f}'
                    .format(epoch, batch_idx * len(data),
                            len(self.train_loader.dataset),
                            100. * batch_idx / len(self.train_loader),
                            pairwise_loss.item()))

        # epoch logs
        train_pairwise_loss = train_pairwise_loss / self.length_train_dataset * self.batch_size
        if self.encoder_type == "gnr":
            train_pairwise_loss_1 = train_pairwise_loss_1 / self.length_train_dataset * self.batch_size
        if self.encoder_type == "MF":
            train_pairwise_loss_2 = train_pairwise_loss_2 / self.length_train_dataset * self.batch_size
        print('====> Epoch: {}  Pairwise loss: {:.8f}'.format(
            epoch, train_pairwise_loss))
        print('\n')

        # tensorboard
        self.tb.add_scalar("contrastive_pairwise_loss/train/sum",
                           train_pairwise_loss, epoch)
        if self.encoder_type == "gnr":
            self.tb.add_scalar("contrastive_pairwise_loss/train/1",
                               train_pairwise_loss_1, epoch)
        if self.encoder_type == "MF":
            self.tb.add_scalar("contrastive_pairwise_loss/train/2",
                               train_pairwise_loss_2, epoch)

        if epoch % self.save_model_every == 0:
            torch.save(
                self.audio_encoder.state_dict(),
                str(
                    Path(self.save_model_loc, self.experiment_name,
                         f'audio_encoder_epoch_{epoch}.pt')))
            torch.save(
                self.mo_encoder.state_dict(),
                str(
                    Path(self.save_model_loc, self.experiment_name,
                         f'mo_encoder_epoch_{epoch}.pt')))
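The mo_encoder above returns one regression head for the CF embeddings (trained with MSE) and one logits head for the tag vocabulary (trained with BCEWithLogitsLoss). A minimal two-head module consistent with those calls might look like this; the real class and its dimensions are not shown in the snippet.

import torch.nn as nn

class MOEncoder(nn.Module):
    def __init__(self, d_in, d_cf, size_voc):
        super().__init__()
        self.cf_head = nn.Linear(d_in, d_cf)       # regression target: CF embedding
        self.tag_head = nn.Linear(d_in, size_voc)  # raw logits for BCEWithLogitsLoss

    def forward(self, z_d_audio):
        return self.cf_head(z_d_audio), self.tag_head(z_d_audio)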
Example #27
    def train_one_epoch_dual_AE(self, epoch):
        """Train one epoch of the dual autoencoder."""
        self.audio_encoder.train()
        self.audio_decoder.train()
        self.tag_encoder.train()
        self.tag_decoder.train()

        # losses
        train_audio_recon_loss = 0
        train_tags_recon_loss = 0
        train_loss = 0
        train_pairwise_loss = 0

        for batch_idx, (data, tags, sound_ids) in enumerate(self.train_loader):
            self.iteration_idx += 1

            x = data.view(-1, 1, 96, 96).to(self.device)
            tags = tags.float().to(self.device)

            # encode
            z_audio, z_d_audio = self.audio_encoder(x)
            z_tags, z_d_tags = self.tag_encoder(tags)

            # audio reconstruction
            x_recon = self.audio_decoder(z_audio)
            audio_recon_loss = kullback_leibler(x_recon, x)

            # tags reconstruction
            tags_recon = self.tag_decoder(z_tags)
            tags_recon_loss = self.tag_recon_loss_function(tags_recon, tags)

            # contrastive loss
            pairwise_loss = contrastive_loss(z_d_audio, z_d_tags,
                                             self.contrastive_temperature)

            # total loss
            loss = audio_recon_loss + tags_recon_loss + pairwise_loss

            # Optimize models
            self.audio_dae_opt.zero_grad()
            self.tag_dae_opt.zero_grad()
            audio_recon_loss.mul(
                self.audio_loss_weight).backward(retain_graph=True)
            tags_recon_loss.mul(
                self.tag_loss_weight).backward(retain_graph=True)
            if self.contrastive_loss_weight:
                pairwise_loss.mul(self.contrastive_loss_weight).backward()
            self.audio_dae_opt.step()
            self.tag_dae_opt.step()

            train_audio_recon_loss += audio_recon_loss.item()
            train_tags_recon_loss += tags_recon_loss.item()
            train_loss += loss.item()
            train_pairwise_loss += pairwise_loss.item()

            # write to tensorboard (debug-only: too much data to log every iteration)
            if False:
                self.tb.add_scalar("iter/audio_recon_loss",
                                   audio_recon_loss.item(), self.iteration_idx)
                self.tb.add_scalar("iter/tag_recon_loss",
                                   tags_recon_loss.item(), self.iteration_idx)
                self.tb.add_scalar("iter/contrastive_pairwise_loss",
                                   pairwise_loss.item(), self.iteration_idx)
                self.tb.add_scalar("iter/total_loss", loss.item(),
                                   self.iteration_idx)

            # logs per batch
            if batch_idx % self.log_interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f} Audio Recon: {:.4f}, '
                    'Tags Recon: {:.4f},  Pairwise: {:.4f}'.format(
                        epoch, batch_idx * len(data),
                        len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader), loss.item(),
                        audio_recon_loss.item(), tags_recon_loss.item(),
                        pairwise_loss.item()))

        # epoch logs
        train_loss = train_loss / self.length_train_dataset * self.batch_size
        train_audio_recon_loss = train_audio_recon_loss / self.length_train_dataset * self.batch_size
        train_tags_recon_loss = train_tags_recon_loss / self.length_train_dataset * self.batch_size
        train_pairwise_loss = train_pairwise_loss / self.length_train_dataset * self.batch_size

        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss))
        print('recon loss audio: {:.4f}'.format(train_audio_recon_loss))
        print('recon loss tags: {:.4f}'.format(train_tags_recon_loss))
        print('pairwise loss: {:.8f}'.format(train_pairwise_loss))
        print('\n')

        # tensorboard
        self.tb.add_scalar("audio_recon_loss/train", train_audio_recon_loss,
                           epoch)
        self.tb.add_scalar("tag_recon_loss/train", train_tags_recon_loss,
                           epoch)
        self.tb.add_scalar("contrastive_pairwise_loss/train",
                           train_pairwise_loss, epoch)
        self.tb.add_scalar("total_loss/train", train_loss, epoch)

        if epoch % self.save_model_every == 0:
            torch.save(
                self.audio_encoder.state_dict(),
                str(
                    Path('saved_models', self.experiment_name,
                         f'audio_encoder_epoch_{epoch}.pt')))
            torch.save(
                self.audio_decoder.state_dict(),
                str(
                    Path('saved_models', self.experiment_name,
                         f'audio_decoder_epoch_{epoch}.pt')))
            torch.save(
                self.tag_encoder.state_dict(),
                str(
                    Path('saved_models', self.experiment_name,
                         f'tag_encoder_epoch_{epoch}.pt')))
            torch.save(
                self.tag_decoder.state_dict(),
                str(
                    Path('saved_models', self.experiment_name,
                         f'tag_decoder_epoch_{epoch}.pt')))
Example #28
    def val(self, epoch):
        """Validation."""
        # There is a bit of code repetition with train_one_epoch here...
        self.audio_encoder.eval()
        self.mo_encoder.eval()

        val_pairwise_loss = 0
        val_pairwise_loss_1 = 0
        val_pairwise_loss_2 = 0
        val_pairwise_loss_3 = 0

        with torch.no_grad():
            for i, (data, tags, cf_embeddings,
                    sound_ids) in enumerate(self.val_loader):

                x = data.view(-1, 1, 48, 256).to(self.device)

                if self.encoder_type in ("gnr", "MF_gnr"):
                    curr_labels = []
                    for curr_tags in tags:
                        # Build multi-hot tag vectors, skipping -1 (padding) entries.
                        non_neg = [t for t in curr_tags if t != -1]
                        new_tags = np.zeros(self.size_voc)
                        new_tags[non_neg] = 1
                        curr_labels.append(new_tags)
                    tags_input = torch.tensor(
                        curr_labels, dtype=torch.float).to(self.device)

                if self.encoder_type in ("MF", "MF_gnr"):
                    cf_input = cf_embeddings.to(self.device)

                # encode
                z_audio, z_d_audio = self.audio_encoder(x)
                z_cf, z_tags = self.mo_encoder(z_d_audio)

                # prediction losses: BCE for tag logits, MSE for CF embeddings
                if self.encoder_type in ("gnr", "MF_gnr"):
                    bceloss = torch.nn.BCEWithLogitsLoss()
                    pairwise_loss_1 = bceloss(z_tags, tags_input)
                    pairwise_loss = pairwise_loss_1
                if self.encoder_type in ("MF", "MF_gnr"):
                    mseloss = torch.nn.MSELoss()
                    pairwise_loss_2 = mseloss(z_cf, cf_input)
                    pairwise_loss = pairwise_loss_2
                if self.encoder_type == "MF_gnr":
                    pairwise_loss = pairwise_loss_1 + pairwise_loss_2

                val_pairwise_loss += pairwise_loss.item()
                if self.encoder_type in ("gnr", "MF_gnr"):
                    val_pairwise_loss_1 += pairwise_loss_1.item()
                if self.encoder_type in ("MF", "MF_gnr"):
                    val_pairwise_loss_2 += pairwise_loss_2.item()

        val_pairwise_loss = val_pairwise_loss / self.length_val_dataset * self.batch_size
        if self.encoder_type in ("gnr", "MF_gnr"):
            val_pairwise_loss_1 = val_pairwise_loss_1 / self.length_val_dataset * self.batch_size
        if self.encoder_type in ("MF", "MF_gnr"):
            val_pairwise_loss_2 = val_pairwise_loss_2 / self.length_val_dataset * self.batch_size

        print('====> Val average pairwise loss: {:.4f}'.format(
            val_pairwise_loss))
        print('\n\n')

        # tensorboard
        self.tb.add_scalar("contrastive_pairwise_loss/val/sum",
                           val_pairwise_loss, epoch)
        if self.encoder_type in ("gnr", "MF_gnr"):
            self.tb.add_scalar("contrastive_pairwise_loss/val/1",
                               val_pairwise_loss_1, epoch)
        if self.encoder_type in ("MF", "MF_gnr"):
            self.tb.add_scalar("contrastive_pairwise_loss/val/2",
                               val_pairwise_loss_2, epoch)
        if not (math.isinf(val_pairwise_loss)
                or math.isnan(val_pairwise_loss)):
            if val_pairwise_loss < self.curr_min_val:
                self.curr_min_val = val_pairwise_loss
                torch.save(
                    self.audio_encoder.state_dict(),
                    str(
                        Path(self.save_model_loc, self.experiment_name,
                             'audio_encoder_epoch_best.pt')))
                torch.save(
                    self.mo_encoder.state_dict(),
                    str(
                        Path(self.save_model_loc, self.experiment_name,
                             'mo_encoder_epoch_best.pt')))
Example #29
# Assumed imports for this snippet (not shown in the excerpt):
import numpy as np
import torch
from torch import optim
from torchvision.utils import save_image


def train_model_cvae(model,
                     data_loader,
                     epochs,
                     latent_dim,
                     categorical_dim,
                     device,
                     SNE_n_iter,
                     anneal,
                     log_interval,
                     temp,
                     hard,
                     temp_min=0.5,
                     ANNEAL_RATE=0.00003):
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    acc_bmm = []

    for epoch in range(1, epochs + 1):

        # Default to the base learning rate so epoch_lr is always defined.
        epoch_lr = 1e-3
        if anneal:
            epoch_lr = adjust_learning_rate(1e-3, optimizer, epoch)

        total_loss = 0
        total_psnr = 0

        total_z = []
        total_labels = []

        for batch_idx, (data, label) in enumerate(data_loader):
            data = data.to(device)

            optimizer.zero_grad()

            recon_batch, qy, z = model(data, temp, hard, latent_dim,
                                       categorical_dim)
            loss = loss_function(recon_batch, data, qy, categorical_dim)

            # Move to CPU before converting to numpy (fails on CUDA otherwise).
            z = z.detach().cpu().numpy().reshape(
                (len(data), latent_dim, categorical_dim))
            z = np.argmax(z, axis=2)

            psnr = PSNR(recon_batch, data.view(-1, 784), 1.0)

            loss.backward()
            total_loss += loss.item() * len(data)
            total_psnr += psnr.item() * len(data)
            optimizer.step()

            # Anneal the Gumbel-softmax temperature every 100 batches.
            if batch_idx % 100 == 1:
                temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx),
                                  temp_min)

            total_z.append(z)
            total_labels = np.concatenate(
                (total_labels, label.detach().numpy()))

        total_z = np.concatenate(total_z)
        bmm_z_pred = cluster_sample(total_z, total_labels, z_type=True)
        acc_h2 = cluster_acc(bmm_z_pred, total_labels)

        acc_bmm = np.append(acc_bmm, acc_h2)

        print(
            '====> Epoch: {} lr:{} Average loss: {:.4f} Average psnr: {:.4f}'.
            format(epoch, epoch_lr, total_loss / len(data_loader.dataset),
                   total_psnr / len(data_loader.dataset)))

        # if epoch % log_interval == 0:
        # 	visualization(total_z, bmm_z_pred, total_labels, SNE_n_iter, epoch, 3)

    # Sample random one-hot categorical codes and decode them into images.
    M = 64 * latent_dim
    np_y = np.zeros((M, categorical_dim), dtype=np.float32)
    np_y[range(M), np.random.choice(categorical_dim, M)] = 1
    np_y = np.reshape(np_y, [M // latent_dim, latent_dim, categorical_dim])
    sample = torch.from_numpy(np_y).view(M // latent_dim,
                                         latent_dim * categorical_dim)
    sample = sample.to(device)
    sample = model.decode(sample).cpu()
    save_image(sample.data.view(M // latent_dim, 1, 28, 28),
               './sample_' + 'cvae' + str(epoch) + '.png')

    return acc_bmm
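
The helpers PSNR, adjust_learning_rate, and loss_function are called above but not shown in this excerpt. A minimal sketch consistent with those call sites; the decay schedule and loss formulation are assumptions (the loss follows the common Gumbel-softmax VAE recipe: reconstruction plus KL against a uniform categorical prior):

import torch
import torch.nn.functional as F


def PSNR(recon, target, max_val):
    # Peak signal-to-noise ratio: 10 * log10(max_val^2 / MSE).
    mse = torch.mean((recon - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)


def adjust_learning_rate(base_lr, optimizer, epoch, decay=0.9):
    # Assumed exponential decay; the original schedule may differ.
    lr = base_lr * (decay ** (epoch - 1))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def loss_function(recon_x, x, qy, categorical_dim, eps=1e-20):
    # Reconstruction term (assumes a sigmoid decoder output) plus
    # KL(q(y|x) || Uniform(categorical_dim)) for the categorical posterior.
    bce = F.binary_cross_entropy(
        recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]
    kld = torch.sum(qy * torch.log(qy * categorical_dim + eps), dim=-1).mean()
    return bce + kld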
Example #30
File: VAE2.py Project: amimai/NNets
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, 3).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')

    import numpy as np

    with torch.no_grad():
        sp_dat = []
        for i, (data, targets) in enumerate(test_loader):
            data = data.to(device)
            mu, logvar = model.encode(data.view(-1, 784))
            rep = model.reparameterize(mu, logvar)
            # Keep latents on CPU so they can be converted to numpy below.
            sp_dat.append((rep.cpu(), targets))

    x = []
    y = []
    z = []
    c = []
    for i in range(5):
        data = sp_dat[i]  # (latent tensors, targets)
        e_val = data[0].numpy()
        truth = data[1].numpy()

        for j in range(len(e_val)):
            w, e, r = e_val[j]
            c.append(truth[j])