Example No. 1
def validation(model, args, lr, epoch, device):
    dataloader, dataset = make_loader(args.cv_list,
                                      args.batch_size,
                                      num_workers=args.num_threads,
                                      processer=Processer(
                                          win_len=args.win_len,
                                          win_inc=args.win_inc,
                                          left_context=args.left_context,
                                          right_context=args.right_context,
                                          fft_len=args.fft_len,
                                          window_type=args.win_type))
    model.eval()
    loss_total = 0.0
    num_batch = len(dataloader)
    stime = time.time()
    # No gradients are needed for cross-validation, so run the whole pass
    # under torch.no_grad() to save memory.
    with torch.no_grad():
        for idx, data in enumerate(dataloader):
            inputs, labels, lengths = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # lengths stay on the CPU; data_parallel scatters them with the inputs.
            outputs, _ = data_parallel(model, (inputs, lengths))
            loss = model.loss(outputs, labels, lengths)
            loss_total += loss.item()
            # Free the batch tensors explicitly to keep GPU memory flat.
            del loss, data, inputs, labels, lengths, _, outputs
        etime = time.time()
        elapsed = (etime - stime) / num_batch
        loss_total_avg = loss_total / num_batch

    print('CROSSVAL AVG.LOSS | Epoch {:3d}/{:3d} '
          '| lr {:.6e} | {:2.3f}s/batch | time {:2.1f}mins '
          '| loss {:2.8f}'.format(epoch + 1, args.max_epoch, lr, elapsed,
                                  (etime - stime) / 60.0, loss_total_avg))
    sys.stdout.flush()
    return loss_total_avg
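
The function above assumes module-level imports (time, sys, torch, and data_parallel from torch.nn.parallel) plus the project's own make_loader and Processer helpers. A minimal sketch of how validation() might be called is shown below; model, args, and the hand-picked learning rate are assumptions standing in for objects built elsewhere in the training script, not part of the original example.

# Hypothetical usage sketch (not part of the original example).
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)      # assumed: a model exposing .loss(outputs, labels, lengths)
current_lr = 1e-3             # assumed: normally read back from the optimizer
val_loss = validation(model, args, current_lr, -1, device)
print('initial cross-validation loss: {:2.8f}'.format(val_loss))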
Example No. 2
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(
        args.tr_list,
        args.batch_size,
        num_workers=args.num_threads,
        processer=Processer(
            win_len=args.win_len,
            win_inc=args.win_inc,
            left_context=args.left_context,
            right_context=args.right_context,
            fft_len=args.fft_len,
            window_type=args.win_type))
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    # Halve the learning rate whenever the validation loss stops improving
    # for more than one epoch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)

    # Resume from the latest checkpoint when retraining, otherwise start fresh.
    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0
    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    # Run one validation pass before training and log it as the starting
    # point of both the training and cross-validation curves.
    val_loss = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)

    for epoch in range(start_epoch, args.max_epoch):
        # Re-seed per epoch so shuffling and any random augmentation are reproducible.
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        loss_total = 0.0
        loss_print = 0.0
        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels, lengths = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # lengths stay on the CPU; data_parallel scatters them with the inputs.

            model.zero_grad()
            outputs, _ = data_parallel(model, (inputs, lengths))

            loss = model.loss(outputs, labels, lengths)

            # Backpropagate, clip the gradient norm, and update the weights.
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            step += 1
            batch_loss = loss.item()
            loss_total += batch_loss
            loss_print += batch_loss

            # Free the batch tensors explicitly to keep GPU memory flat.
            del lengths, outputs, labels, inputs, loss, _
            # Save an intermediate checkpoint every 3000 batches.
            if (idx + 1) % 3000 == 0:
                save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir)
            # Report and log the running training loss every print_freq batches.
            if (idx + 1) % print_freq == 0:
                elapsed = time.time() - stime
                speed_avg = elapsed / (idx + 1)
                loss_print_avg = loss_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} | '
                      '{:2.3f}s/batch | loss {:2.6f}'.format(
                          epoch + 1, args.max_epoch, idx + 1, num_batch, lr,
                          speed_avg, loss_print_avg))
                sys.stdout.flush()
                writer.add_scalar('Loss/Train', loss_print_avg, step)
                loss_print = 0.0
        # Epoch summary: average loss and timing over the whole epoch.
        elapsed = time.time() - stime
        loss_total_avg = loss_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' loss {:2.6f}'.format(epoch + 1,
                                     args.max_epoch,
                                     lr,
                                     elapsed / num_batch,
                                     elapsed / 60.0,
                                     loss_total_avg))
        # Cross-validate, log the results, and keep a checkpoint only when the
        # validation loss improves on the best value seen so far.
        val_loss = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir)
        # Let the scheduler decide whether to reduce the learning rate.
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
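
The training loop above relies on several helpers that are not shown here (get_learning_rate, reload_model, save_checkpoint, make_loader, Processer) and on module-level imports of time, sys, torch, torch.nn as nn, torch.optim as optim, and data_parallel from torch.nn.parallel. The sketch below shows one plausible way to wire train() up with a TensorBoard writer, together with an assumed implementation of get_learning_rate(); both are illustrative guesses, not the repository's actual code.

# Hypothetical wiring sketch (not part of the original example).
import torch
from torch.utils.tensorboard import SummaryWriter


def get_learning_rate(optimizer):
    # Assumed behaviour: report the learning rate of the first parameter group.
    return optimizer.param_groups[0]['lr']


device = torch.device('cuda' if getattr(args, 'use_cuda', False) else 'cpu')
model = model.to(device)                       # assumed: the model being trained
writer = SummaryWriter(log_dir=args.exp_dir)   # reuse args.exp_dir, as save_checkpoint does above
train(model, args, device, writer)
writer.close()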