Esempio n. 1
0
def load_all_model(root_dir, device=0, num_classes=17):
    """Load every per-class model group saved under ``root_dir``.

    For each class index ``i`` in ``range(num_classes)`` this reads
    ``<root_dir>/<i>/config.json``, turns it into an ``argparse.Namespace``
    and builds ``args.model_num`` models of ``args.model_type``. Each model
    is restored from
    ``<args.checkpoint_path>/<i>/<model_type>_<type_suffix>/model_<j>.pth``.

    Args:
        root_dir: Directory containing one sub-directory per class index.
        device: CUDA device index every model is moved to.
        num_classes: Number of class sub-directories to load. The default of
            17 matches the previously hard-coded value.

    Returns:
        A list with one entry per class. Each entry is the list of that
        class's models, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if a class's ``config.json`` is missing.
        NotImplementedError: if ``args.model_type`` is not one of
            ``'lstm'``, ``'conv'``, ``'char'`` or ``'base'``.
        SystemExit: with status 1 if an expected checkpoint file is missing.
    """
    import sys

    # Config ``model_type`` value -> name of the class in ``models``.
    model_class_names = {
        'lstm': 'LSTMModel',
        'conv': 'ConvModel',
        'char': 'CharCNNModel',
        'base': 'BaseModel',
    }
    model_corpus = []
    for i in range(num_classes):
        config_file = os.path.join(root_dir, str(i), "config.json")
        # config.json is written as UTF-8 JSON, so read it the same way.
        with open(config_file, 'r', encoding='utf-8') as fin:
            config = json.load(fin)
        args = argparse.Namespace(**config)
        item = []
        for j in range(args.model_num):
            class_name = model_class_names.get(args.model_type)
            if class_name is None:
                raise NotImplementedError(args.model_type)
            model = getattr(models, class_name)(args)
            model_path = os.path.join(
                args.checkpoint_path, str(i),
                "%s_%s" % (args.model_type, args.type_suffix),
                "model_%d.pth" % j)
            if not os.path.isfile(model_path):
                print("No model to test")
                # sys.exit is always available; the bare ``exit`` builtin is
                # injected by ``site`` and is missing under ``python -S``.
                sys.exit(1)
            # Load to CPU first so a checkpoint saved from any GPU can be read.
            # The model is moved to ``device`` on the next line.
            model.load_state_dict(torch.load(model_path, map_location='cpu'))
            model = model.cuda(device)
            model.eval()
            item.append(model)
        model_corpus.append(item)
    return model_corpus
def build_model(model_type, model_part, learned, seq_len, feature_num, kmer, clayer_num, filters, layer_sizes, embedding_dim, activation, output_activation, transfer, transfer_dim, dropout=0.1):
    """Construct the model named by ``model_type``, moving it to the GPU when one exists.

    Convolution-style models receive ``clayer_num`` identical layer specs of
    the form ``{"kernel_size": kmer, "filters": filters, "activation": ...}``.
    The activation is ``"ReLU"`` for ConvModel/SpannyConvModel and ``"Tanh"``
    for MHCflurry. The remaining arguments are forwarded positionally to the
    matching ``models`` constructor.

    Returns:
        The constructed model. ``.cuda()`` is called on it when CUDA is
        available.

    Raises:
        ValueError: if ``model_type`` is not a supported model name.
    """
    def stacked_specs(act_name):
        # A fresh spec dict for each of the ``clayer_num`` layers.
        return [
            {"kernel_size": kmer, "filters": filters, "activation": act_name}
            for _ in range(clayer_num)
        ]

    def make_conv():
        return models.ConvModel(
            model_part, seq_len, feature_num, stacked_specs("ReLU"),
            layer_sizes, learned, embedding_dim, activation,
            output_activation, transfer, transfer_dim, dropout,
            posembed=False)

    def make_spanny_conv():
        conv_specs = stacked_specs("ReLU")
        global_spec = {"kernel_size": kmer, "filters": filters, "activation": "ReLU"}
        return models.SpannyConvModel(
            model_part, seq_len, feature_num, global_spec, conv_specs,
            layer_sizes, learned, embedding_dim, activation,
            output_activation, transfer, transfer_dim, dropout,
            posembed=False)

    def make_mhcflurry():
        return models.MHCflurry(
            model_part, seq_len, feature_num, stacked_specs("Tanh"),
            layer_sizes, learned, embedding_dim, activation,
            output_activation, transfer, transfer_dim, dropout)

    def make_transformer():
        # NOTE(review): what each positional argument means depends on
        # models.Transformer's signature. Note that dropout is not passed here.
        doubled = feature_num * 2
        per_filter = int(feature_num / filters)
        return models.Transformer(
            model_part, seq_len, feature_num, doubled, filters,
            per_filter, per_filter, layer_sizes, learned, embedding_dim,
            activation, output_activation, transfer, transfer_dim)

    factories = {
        "ConvModel": make_conv,
        "SpannyConvModel": make_spanny_conv,
        "MHCflurry": make_mhcflurry,
        "Transformer": make_transformer,
    }
    factory = factories.get(model_type)
    if factory is None:
        raise ValueError("Unsupported model type : " + model_type)

    net = factory()
    if torch.cuda.is_available():
        net.cuda()
    return net
def _cycle_loader(loader):
    """Yield batches from ``loader`` endlessly, re-iterating it on every pass.

    ``itertools.cycle`` would replay a cached copy of the first pass. That
    defeats ``shuffle=True`` and keeps a whole epoch of batches in memory.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing. Without
            this check the generator would spin forever without yielding.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("training loader yielded no batches")


def _unwrap_model(model):
    """Return the module wrapped by ``torch.nn.DataParallel``, or ``model`` itself."""
    if isinstance(model, torch.nn.DataParallel):
        return model.module
    return model


def _save_training_state(path, model, optimizer, model_name):
    """Write the current model/optimizer state dicts and ``model_name`` to ``path``."""
    torch.save(
        {
            # Save the unwrapped module so the checkpoint loads without DataParallel.
            'model_state_dict': _unwrap_model(model).state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'model_name': model_name,
        }, path)


def _validation_loss(model, loader, num_examples, device):
    """Sum the ``run_on_batch`` loss over ``loader`` and divide by ``num_examples``.

    The model runs in eval mode under ``torch.no_grad()`` and is always put
    back into train mode afterwards.
    """
    model.eval()
    try:
        total = 0.0
        with torch.no_grad():
            for batch in loader:
                _, loss = models.run_on_batch(model,
                                              batch[0],
                                              batch[1],
                                              device=device)
                total += loss.item()
    finally:
        model.train()
    # NOTE(review): this adds one loss per batch and divides by the number of
    # examples. That gives a per-example mean only if run_on_batch returns a
    # summed loss; if it returns a batch mean, divide by len(loader) instead.
    # Confirm which one run_on_batch returns.
    return total / num_examples


def train(logdir, device, model_name, iterations, resume_iteration,
          checkpoint_interval, load_mode, num_workers, batch_size,
          sequence_length, model_complexity, learning_rate,
          learning_rate_decay_steps, learning_rate_decay_rate,
          clip_gradient_norm, validation_interval, print_interval, debug):
    """Train a ``models.ConvModel`` on segment or onset excerpts.

    TensorBoard logs go to ``logdir + model_name``. A rolling
    ``model-checkpoint.pt`` is saved every ``checkpoint_interval`` iterations,
    and the best model so far (by validation loss) is saved to
    ``model-best-val-loss``.

    Args:
        logdir: Prefix of the run directory. ``model_name`` is appended to it
            verbatim, with no separator inserted.
        device: Sequence of CUDA device ids. An empty sequence trains on CPU;
            two or more ids wrap the model in ``torch.nn.DataParallel``.
        model_name: ``"SegmentConv"`` or ``"OnsetConv"``; selects the dataset.
        iterations: Last iteration to run (inclusive). ``None`` trains until
            interrupted.
        resume_iteration: ``None`` for a fresh run. Otherwise, the iteration at
            which ``<run dir>/model-checkpoint.pt`` was saved. Model and
            optimizer state are restored, and counting continues from
            ``resume_iteration + 1``.
        checkpoint_interval: Iterations between rolling checkpoint saves.
        num_workers: Number of ``DataLoader`` worker processes.
        batch_size: Batch size for both loaders.
        learning_rate: Adam learning rate.
        learning_rate_decay_steps: ``StepLR`` step size, in iterations.
        learning_rate_decay_rate: ``StepLR`` decay factor.
        clip_gradient_norm: Maximum gradient norm; a falsy value disables
            clipping.
        validation_interval: Iterations between validation passes.
        print_interval: Iterations between training-loss reports.
        load_mode, sequence_length, model_complexity, debug: Accepted for
            interface compatibility; this function does not use them.

    Raises:
        ValueError: if ``model_name`` is not supported.
        FileExistsError: on a fresh run whose run directory already exists.
    """
    default_device = 'cpu' if len(device) == 0 else 'cuda:{}'.format(device[0])
    print("Training on: {}".format(default_device))

    print("Running a {}-model".format(model_name))
    if model_name == "SegmentConv":
        dataset_class = SegmentExcerptDataset
    elif model_name == "OnsetConv":
        dataset_class = OnsetExcerptDataset
    else:
        # Any other name would leave dataset_class unbound.
        raise ValueError("Unsupported model_name: {}".format(model_name))

    logdir += model_name
    # A resumed run loads its checkpoint from logdir, so the directory must be
    # allowed to exist. A fresh run still refuses to overwrite an old one.
    os.makedirs(logdir, exist_ok=resume_iteration is not None)
    writer = SummaryWriter(logdir)
    try:
        dataset = dataset_class(set='train')
        validation_dataset = dataset_class(set='test')

        loader = DataLoader(dataset,
                            batch_size,
                            shuffle=True,
                            num_workers=num_workers)
        validation_loader = DataLoader(validation_dataset,
                                       batch_size,
                                       shuffle=False,
                                       num_workers=num_workers)

        model = models.ConvModel(device=default_device)
        checkpoint = None
        if resume_iteration is None:
            # Xavier-initialise conv weights only for a fresh model. Doing this
            # after loading a checkpoint would discard the restored weights.
            for layer in model.modules():
                if isinstance(layer, torch.nn.Conv2d):
                    torch.nn.init.xavier_uniform_(layer.weight)
            resume_iteration = 0
        else:
            model_state_path = os.path.join(logdir, 'model-checkpoint.pt')
            checkpoint = torch.load(model_state_path,
                                    map_location=default_device)
            model.load_state_dict(checkpoint['model_state_dict'])

        model = model.to(default_device)
        if len(device) >= 2:
            # Wrap in both the fresh and the resumed path, so every run has
            # the same model structure when it is saved.
            model = torch.nn.DataParallel(model, device_ids=device)

        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        if checkpoint is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # NOTE(review): the scheduler state is not saved in the checkpoint, so
        # StepLR's step count starts over when a run is resumed.
        scheduler = StepLR(optimizer,
                           step_size=learning_rate_decay_steps,
                           gamma=learning_rate_decay_rate)

        total_loss = 0.0
        # Start at infinity (not 1) so the first validation always produces a
        # best-model checkpoint, whatever the scale of the loss.
        best_validation_loss = float('inf')
        batches = _cycle_loader(loader)
        i = resume_iteration + 1
        while iterations is None or i <= iterations:
            batch = next(batches)
            optimizer.zero_grad()
            _, loss = models.run_on_batch(model,
                                          batch[0],
                                          batch[1],
                                          device=default_device)
            loss.backward()
            if clip_gradient_norm:
                clip_grad_norm_(model.parameters(), clip_gradient_norm)

            optimizer.step()
            # PyTorch >= 1.1 expects scheduler.step() after optimizer.step().
            # The reverse order skips the first learning-rate value.
            scheduler.step()

            total_loss += loss.item()

            if i % print_interval == 0:
                print("total_train_loss: {:.3f} minibatch: {:6d}/{:6d}".format(
                    total_loss / print_interval, i, len(loader)))
                writer.add_scalar('data/loss',
                                  total_loss / print_interval,
                                  global_step=i)
                total_loss = 0.0

            if i % checkpoint_interval == 0:
                _save_training_state(os.path.join(logdir, 'model-checkpoint.pt'),
                                     model, optimizer, model_name)

            if i % validation_interval == 0:
                validation_loss = _validation_loss(model, validation_loader,
                                                   len(validation_dataset),
                                                   default_device)
                print("total_valid_loss: {:.3f} minibatch: {:6d}".format(
                    validation_loss, i))
                writer.add_scalar('data/valid_loss',
                                  validation_loss,
                                  global_step=i)

                if validation_loss < best_validation_loss:
                    best_validation_loss = validation_loss
                    # Capture the weights now rather than reusing a state dict
                    # from an earlier checkpoint save.
                    _save_training_state(
                        os.path.join(logdir, 'model-best-val-loss'),
                        model, optimizer, model_name)

            i += 1
    finally:
        writer.close()
Esempio n. 4
0
def train(args, model_id, tb):
    """Train model number ``model_id`` for class ``args.class_id``.

    The model resumes from its checkpoint when one exists. Weights and
    ``config.json`` copies are saved every ``args.save_every`` iterations, and
    losses are logged to ``tb``. Training stops early once the validation loss
    has risen, compared with the previous validation, on more than
    ``args.patient`` consecutive validations.

    Args:
        args: Namespace of hyper-parameters and paths (seed, data paths,
            model_type, checkpoint_path, type_suffix, epochs, learning-rate
            schedule, logging/saving intervals, patient, device, ...).
        model_id: Index of this model. It appears in the checkpoint file name
            ``model_<model_id>.pth`` and in the TensorBoard tags.
        tb: Writer exposing ``add_scalar(tag, value, step)``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not ``'lstm'``,
            ``'conv'``, ``'char'`` or ``'base'``.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    train_data = MedicalEasyEnsembleDataloader(args.train_data, args.class_id,
                                               args.batch_size, True,
                                               args.num_workers)
    val_data = MedicalEasyEnsembleDataloader(args.val_data, args.class_id,
                                             args.batch_size, False,
                                             args.num_workers)
    if os.path.exists(args.w2v_file):
        embedding = utils.load_embedding(args.w2v_file,
                                         vocab_size=args.vocab_size,
                                         embedding_size=args.embedding_size)
    else:
        embedding = None

    # Config ``model_type`` value -> name of the class in ``models``.
    model_class_names = {
        'lstm': 'LSTMModel',
        'conv': 'ConvModel',
        'char': 'CharCNNModel',
        'base': 'BaseModel',
    }
    class_name = model_class_names.get(args.model_type)
    if class_name is None:
        raise NotImplementedError(args.model_type)
    model = getattr(models, class_name)(args, embedding)

    # Build the checkpoint locations once instead of at every use.
    class_dir = os.path.join(args.checkpoint_path, str(args.class_id))
    model_dir = os.path.join(class_dir,
                             "%s_%s" % (args.model_type, args.type_suffix))
    model_path = os.path.join(model_dir, "model_%d.pth" % model_id)

    if os.path.isfile(model_path):
        print("Load %d class %s type %dth model from previous step" %
              (args.class_id, args.model_type, model_id))
        # Load to CPU first so a checkpoint written from another GPU can be
        # read. The model is moved to args.device just below.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    iteration = 0
    model = model.cuda(args.device)
    model.train()
    optimizer = utils.build_optimizer(args, model)
    loss_func = MultiBceLoss()
    # Validation loss from the previous validation. 1000 acts as "no
    # previous value yet".
    cur_worse = 1000
    bad_times = 0  # consecutive validations whose loss went up
    for epoch in range(args.epochs):
        if epoch >= args.start_epoch:
            factor = (epoch - args.start_epoch) // args.decay_every
            decay_factor = args.decay_rate**factor
            current_lr = args.lr * decay_factor
            utils.set_lr(optimizer, current_lr)
        for i, data in enumerate(train_data):
            tmp = [
                _.cuda(args.device) if isinstance(_, torch.Tensor) else _
                for _ in data
            ]
            report_ids, sentence_ids, sentence_lengths, output_vec = tmp
            optimizer.zero_grad()
            loss = loss_func(model(sentence_ids, sentence_lengths), output_vec)
            loss.backward()
            train_loss = loss.item()
            optimizer.step()
            iteration += 1
            if iteration % args.print_every == 0:
                print("iter %d epoch %d loss: %.3f" %
                      (iteration, epoch, train_loss))

            if iteration % args.save_every == 0:
                # torch.save and open() do not create missing directories.
                os.makedirs(model_dir, exist_ok=True)
                torch.save(model.state_dict(), model_path)
                for config_dir in (class_dir, model_dir):
                    with open(os.path.join(config_dir, "config.json"),
                              'w',
                              encoding='utf-8') as config_f:
                        json.dump(vars(args), config_f, indent=2)
            if iteration % args.val_every == 0:
                val_loss = eval_model(model, loss_func, val_data, epoch)
                # NOTE(review): eval_model's body is not visible here. If it
                # switches the model to eval mode, this restores training mode;
                # otherwise the call is a no-op.
                model.train()
                tb.add_scalar("model_%d val_loss" % model_id, val_loss,
                              iteration)
                if val_loss > cur_worse:
                    print("Bad Time Appear")
                    cur_worse = val_loss
                    bad_times += 1
                else:
                    cur_worse = val_loss
                    bad_times = 0
                if bad_times > args.patient:
                    print('Early Stop !!!!')
                    return
            if iteration % args.loss_log_every == 0:
                tb.add_scalar("model_%d train_loss" % model_id, train_loss,
                              iteration)

    print("The train finished")
            plt.fill_between(frmtimes, 0, 0.5, where=segframes_gt >
                             0, facecolor='green', alpha=0.7, label='ground truth')
            plt.fill_between(frmtimes, -0.5, 0, where=segframes_est >
                             0, facecolor='orange', alpha=0.7, label='estimation')
            # plt.title("Pedal segment detection of {}".format(filename))
            plt.legend()
            # plt.xlim([left,right])
            # plt.show()
            plt.savefig("test")

            return segframes_est


if __name__ == '__main__':
    # Smoke-test entry point: restore two trained ConvModel checkpoints and
    # load one audio recording from the dataset.
    # Load some test models
    onset_model = models.ConvModel()
    segment_model = models.ConvModel()

    # Each checkpoint file is a dict; only its 'model_state_dict' entry is used.
    # NOTE(review): torch.load is called without map_location, so checkpoints
    # saved from a GPU may fail to load on a CPU-only machine. Confirm this is
    # intended.
    onset_model.load_state_dict(torch.load(
        "test_models/onset_conv.pt")["model_state_dict"])
    segment_model.load_state_dict(torch.load(
        "test_models/segment_conv.pt")["model_state_dict"])

    # Recording path relative to ORIGINAL_DATA_PATH, without file extension.
    file_name = "2011/MIDI-Unprocessed_22_R1_2011_MID--AUDIO_R1-D8_12_Track12_wav"

    # Plain string concatenation: this assumes ORIGINAL_DATA_PATH ends with a
    # path separator. TODO confirm.
    test_audio_file = ORIGINAL_DATA_PATH + file_name + ".flac"
    # NOTE(review): test_midi_file is not used in the lines visible here.
    test_midi_file = ORIGINAL_DATA_PATH + file_name + ".midi"

    # Load any audio file to test
    # Audio is requested at SAMPLING_RATE; sr holds the rate librosa returns.
    audio, sr = librosa.load(test_audio_file, sr=SAMPLING_RATE)
    # Length along the first axis (presumably the sample count for mono audio).
    print(audio.shape[0])