Example #1
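A tail fragment of a cross-modal autoencoder training loop: per-epoch losses and RNA/ATAC classifier accuracies are appended to an open log file f, and every args.save_freq epochs each network is checkpointed to args.save_dir.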
            ", AE clf loss: %.8f" % float(AE_clf_loss),
            ", clf loss: %.8f" % float(clf_loss),
            ", clf class loss: %.8f" % float(clf_class_loss),
            ", clf accuracy RNA: %.4f" % float(n_rna_correct / n_rna_total),
            ", clf accuracy ATAC: %.4f" % float(n_atac_correct / n_atac_total),
            file=f,
        )

    # save model
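    # (saving via .cpu() keeps the checkpoints device-agnostic; the
    # networks are moved back to the GPU at the end of the block)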
    if epoch % args.save_freq == 0:
        torch.save(
            netRNA.cpu().state_dict(),
            os.path.join(args.save_dir, "netRNA_%s.pth" % epoch),
        )
        torch.save(
            netImage.cpu().state_dict(),
            os.path.join(args.save_dir, "netImage_%s.pth" % epoch),
        )
        torch.save(
            netClf.cpu().state_dict(),
            os.path.join(args.save_dir, "netClf_%s.pth" % epoch),
        )
        if args.conditional:
            torch.save(
                netCondClf.cpu().state_dict(),
                os.path.join(args.save_dir, "netCondClf_%s.pth" % epoch),
            )

    if args.use_gpu:
        netRNA.cuda()
        netClf.cuda()
        netImage.cuda()
        if args.conditional:
            netCondClf.cuda()
Example #2
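A complete main() for training a music VAE on .npy note-sequence files: it parses hyperparameters, builds zero-padded batches of variable-length sequences, trains with optional exponential LR decay, and saves the parameters after every epoch.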
import argparse
import os

import numpy as np
import torch
from torch import LongTensor
from torch.utils.tensorboard import SummaryWriter

# find(), VAE, MinExponentialLR, train() and data_dir are project-specific
# and assumed to be defined elsewhere; SummaryWriter may instead come from
# tensorboardX in the original project.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden",
                        '-hid',
                        type=int,
                        default=768,
                        help="hidden state dimension")
    parser.add_argument('--epochs',
                        '-e',
                        type=int,
                        default=5,
                        help="number of epochs")
    parser.add_argument('--learning_rate',
                        '-lr',
                        type=float,
                        default=1e-4,
                        help="learning rate")
    parser.add_argument('--grudim',
                        '-gd',
                        type=int,
                        default=1024,
                        help='dimension for gru layer')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=64,
                        help='input batch size for training')
    parser.add_argument('--name',
                        '-n',
                        type=str,
                        default='embedded',
                        help='tensorboard visual name')
    parser.add_argument('--decay',
                        '-d',
                        type=float,
                        default=-1,
                        help='learning rate decay: Gamma')
    parser.add_argument('--beta', type=float, default=0.1, help='beta for kld')
    parser.add_argument('--data',
                        type=int,
                        default=1000,
                        help='how many pieces of music to use')

    args = parser.parse_args()

    hidden_dim = args.hidden
    epochs = args.epochs
    gru_dim = args.grudim
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    decay = args.decay
    beta = args.beta
    data_num = args.data

    folder_name = "hid%d_e%d_gru%d_lr%.4f_batch%d_decay%.4f_beta%.2f_data%d" % (
        hidden_dim, epochs, gru_dim, learning_rate, batch_size, decay, beta,
        data_num)

    writer = SummaryWriter('../logs/{}'.format(folder_name))

    # load data
    file_list = find('*.npy', data_dir)
    f = np.load(data_dir + file_list[0])
    note_dim = f.shape[1]

    model = VAE(note_dim, gru_dim, hidden_dim, batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if decay > 0:
        scheduler = MinExponentialLR(optimizer, gamma=decay, minimum=1e-5)
    step = 0

    if torch.cuda.is_available():
        print('Using: ',
              torch.cuda.get_device_name(torch.cuda.current_device()))
        model.cuda()
    else:
        print('CPU mode')

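    # Epoch loop: files accumulate in batch_data and are flushed as one
    # zero-padded, length-sorted batch every batch_size files (and at the
    # end of the file list or the data budget).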
    for epoch in range(1, epochs + 1):  # range(1, epochs) would skip the last epoch
        print("#" * 5, epoch, "#" * 5)
        batch_data = []
        batch_num = 0
        max_len = 0
        for i in range(len(file_list)):
            if ((i != 0 and i % batch_size == 0) or i == len(file_list) - 1
                    or i == data_num):
                # create a batch by zero padding
                print("#" * 5, "batch", batch_num)
                # the final batch may be smaller than batch_size; use the
                # actual count rather than overwriting batch_size, which
                # would corrupt the batching logic in later epochs
                cur_batch_size = len(batch_data)
                seq_lengths = LongTensor(list(map(len, batch_data)))
                print(seq_lengths.size())
                max_len = torch.max(seq_lengths).item()
                print("max_len:", max_len)
                batch = np.zeros((max_len, cur_batch_size, note_dim))
                for j in range(len(batch_data)):
                    batch[:batch_data[j].shape[0], j, :] = batch_data[j]
                batch = torch.from_numpy(batch)
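                # sort by length (descending), as packed RNN input requires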
                seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
                batch = batch[:, perm_idx, :]
                step = train(model, batch, seq_lengths, step, optimizer, beta,
                             writer)
                # reset
                max_len = 0
                batch_data = []
                if decay > 0:
                    scheduler.step()
                batch_num += 1
            data = np.load(data_dir + file_list[i])
            batch_data.append(data)
            if i == data_num:
                break

        print("# saving params")
        param_name = "hid%d_e%d_gru%d_lr%.4f_batch%d_decay%.4f_beta%.2f_data%d_epoch%d" % (
            hidden_dim, epochs, gru_dim, learning_rate, batch_size, decay,
            beta, data_num, epoch)
        save_path = '../params/{}.pt'.format(param_name)
        # create the directory we actually save into ('../params')
        os.makedirs('../params', exist_ok=True)
        if torch.cuda.is_available():
            torch.save(model.cpu().state_dict(), save_path)
            model.cuda()
        else:
            torch.save(model.state_dict(), save_path)
        print('# Model saved!')

    writer.close()
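MinExponentialLR is not a stock PyTorch scheduler. Below is a minimal sketch of a class consistent with how it is called above (exponential decay with a floor on the learning rate); treat it as an assumption, not the project's actual code:

from torch.optim.lr_scheduler import ExponentialLR


class MinExponentialLR(ExponentialLR):
    def __init__(self, optimizer, gamma, minimum, last_epoch=-1):
        self.min = minimum
        super().__init__(optimizer, gamma, last_epoch=last_epoch)

    def get_lr(self):
        # decay exponentially, but never drop below the configured minimum
        return [
            max(base_lr * self.gamma**self.last_epoch, self.min)
            for base_lr in self.base_lrs
        ]

The manual zero-padding and sorting in the batch loop can also be expressed with torch.nn.utils.rnn; a hypothetical helper (make_batch is an illustrative name, not from the source):

import torch
from torch.nn.utils.rnn import pad_sequence


def make_batch(batch_data):
    tensors = [torch.from_numpy(a) for a in batch_data]  # each (T_i, note_dim)
    seq_lengths = torch.as_tensor([t.size(0) for t in tensors])
    # descending length order, as required for packed RNN input
    seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
    batch = pad_sequence(tensors)[:, perm_idx, :]  # (max_len, batch, note_dim)
    return batch, seq_lengths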
Example #3
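The tail of a train(step) function for a VAE with two latent distributions (melody and rhythm reconstruction targets), followed by the outer loop that checkpoints the model each time the data loader dl rolls over into a new epoch. Normal is torch.distributions.Normal; loss_function, dl, model and save_path are defined earlier in the project.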
    distribution_1 = Normal(dis1m, dis1s)
    distribution_2 = Normal(dis2m, dis2s)
    loss = loss_function(recon,
                         recon_rhythm,
                         target_tensor,
                         rhythm_target,
                         distribution_1,
                         distribution_2,
                         step,
                         beta=args['beta'])
    loss.backward()
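    # clip gradients to unit norm to keep the recurrent layers stable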
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    optimizer.step()
    step += 1
    print('batch loss: {:.5f}'.format(loss.item()))
    writer.add_scalar('batch_loss', loss.item(), step)
    if args['decay'] > 0:
        scheduler.step()
    dl.shuffle_samples()
    return step


while dl.get_n_epoch() < args['n_epochs']:
    step = train(step)
    if dl.get_n_epoch() != pre_epoch:
        pre_epoch = dl.get_n_epoch()
        torch.save(model.cpu().state_dict(), save_path)
        if torch.cuda.is_available():
            model.cuda()
        print('Model saved!')
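loss_function itself is defined elsewhere in the project. A minimal sketch of what it plausibly computes, assuming cross-entropy reconstruction terms plus a beta-weighted KL divergence of both posteriors against a standard-normal prior (the real implementation may differ, e.g. by using step for KL annealing):

import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def loss_function(recon, recon_rhythm, target_tensor, rhythm_target,
                  distribution_1, distribution_2, step, beta=0.1):
    # reconstruction: flatten (time, batch, classes) logits against index targets
    recon_loss = F.cross_entropy(recon.view(-1, recon.size(-1)),
                                 target_tensor.view(-1))
    rhythm_loss = F.cross_entropy(recon_rhythm.view(-1, recon_rhythm.size(-1)),
                                  rhythm_target.view(-1))

    def kld(dist):
        # KL of the approximate posterior against N(0, 1)
        prior = Normal(torch.zeros_like(dist.mean),
                       torch.ones_like(dist.stddev))
        return kl_divergence(dist, prior).mean()

    # step is unused in this sketch; the real code may use it to schedule beta
    return recon_loss + rhythm_loss + beta * (
        kld(distribution_1) + kld(distribution_2))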