Example #1
0
def train(args,model,processor):
    """Run the full training loop for a sequence-labeling model.

    Per epoch: train on all batches, evaluate, step the LR scheduler on
    eval F1, and save a checkpoint whenever eval F1 improves.

    Args:
        args: namespace carrying batch_size, seed, label2id, device, epochs,
            learning_rate, grad_norm, arch and output_dir (a pathlib.Path).
        model: network exposing forward_loss(ids, mask, lens, tags) and,
            optionally, wrapped in nn.DataParallel.
        processor: data processor providing .vocab for the loader.
    """
    train_dataset = load_and_cache_examples(args, processor, data_type='train')
    # NOTE(review): shuffle=False with sort=True means batches are
    # length-sorted and never shuffled across epochs — confirm intentional.
    train_loader = DatasetLoader(data=train_dataset, batch_size=args.batch_size,
                                 shuffle=False, seed=args.seed, sort=True,
                                 vocab = processor.vocab,label2id = args.label2id)
    # Optimize only trainable parameters (frozen ones are skipped).
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr=args.learning_rate)
    # mode='max': LR is halved when eval F1 stops improving for 3 epochs.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3,
                                  verbose=1, epsilon=1e-4, cooldown=0, min_lr=0, eps=1e-8)
    best_f1 = 0
    for epoch in range(1, 1 + args.epochs):
        print(f"Epoch {epoch}/{args.epochs}")
        pbar = ProgressBar(n_total=len(train_loader), desc='Training')
        train_loss = AverageMeter()
        model.train()
        assert model.training
        for step, batch in enumerate(train_loader):
            input_ids, input_mask, input_tags, input_lens = batch
            input_ids = input_ids.to(args.device)
            input_mask = input_mask.to(args.device)
            input_tags = input_tags.to(args.device)
            # input_lens stays on CPU; forward_loss returns (features, loss).
            features, loss = model.forward_loss(input_ids, input_mask, input_lens, input_tags)
            loss.backward()
            # Clip gradients before the optimizer step to bound update size.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            pbar(step=step, info={'loss': loss.item()})
            train_loss.update(loss.item(), n=1)
        print(" ")
        train_log = {'loss': train_loss.avg}
        # Free cached GPU memory before the evaluation pass.
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        eval_log, class_info = evaluate(args,model,processor)
        logs = dict(train_log, **eval_log)
        show_info = f'\nEpoch: {epoch} - ' + "-".join([f' {key}: {value:.4f} ' for key, value in logs.items()])
        logger.info(show_info)
        # Custom scheduler API: step once per epoch on the eval F1 score.
        scheduler.epoch_step(logs['eval_f1'], epoch)
        if logs['eval_f1'] > best_f1:
            logger.info(f"\nEpoch {epoch}: eval_f1 improved from {best_f1} to {logs['eval_f1']}")
            logger.info("save model to disk.")
            best_f1 = logs['eval_f1']
            # Unwrap DataParallel so the saved state_dict has no 'module.' prefix.
            if isinstance(model, nn.DataParallel):
                model_stat_dict = model.module.state_dict()
            else:
                model_stat_dict = model.state_dict()
            state = {'epoch': epoch, 'arch': args.arch, 'state_dict': model_stat_dict}
            model_path = args.output_dir / 'best-model.bin'
            torch.save(state, str(model_path))
            print("Eval Entity Score: ")
            # Log per-entity-class precision/recall/F1 for the new best model.
            for key, value in class_info.items():
                info = f"Subject: {key} - Acc: {value['acc']} - Recall: {value['recall']} - F1: {value['f1']}"
                logger.info(info)
Example #2
0
def initialize(mode, is_gpu, dir_data, di_set_transform, ext_img,
               n_img_per_batch, n_worker):
    """Build dataloaders, network, loss, optimizer and LR scheduler.

    The dataloader factory is chosen by `mode`; any unrecognized mode
    falls through to the TensorDataset-based loader.

    Returns:
        (trainloader, testloader, net, criterion, optimizer, scheduler,
         li_class)
    """
    if mode == 'TORCHVISION_MEMORY':
        # This factory does not take ext_img.
        trainloader, testloader, li_class = make_dataloader_torchvison_memory(
            dir_data, di_set_transform, n_img_per_batch, n_worker)
    else:
        if mode == 'TORCHVISION_IMAGEFOLDER':
            factory = make_dataloader_torchvison_imagefolder
        elif mode == 'CUSTOM_MEMORY':
            factory = make_dataloader_custom_memory
        elif mode == 'CUSTOM_FILE':
            factory = make_dataloader_custom_file
        else:
            factory = make_dataloader_custom_tensordataset
        trainloader, testloader, li_class = factory(
            dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    if is_gpu:
        net.cuda()
        criterion.cuda()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # Decay LR when the monitored quantity stops decreasing.
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=1, patience=8,
                                  epsilon=0.00001, min_lr=0.000001)

    return trainloader, testloader, net, criterion, optimizer, scheduler, li_class
Example #3
0
def initialize(
        dev,
        dir_data,
        size_img,
        ext_img,
        n_img_per_batch,
        n_worker,
        li_idx_sample_ratio=None
):
    """Create file-based dataloaders and an MSE-trained Network on `dev`.

    Returns:
        (trainloader, testloader, net, criterion, optimizer, scheduler,
         li_idx_sample, li_fn_sample)
    """
    trainloader, testloader, li_idx_sample, li_fn_sample = \
        make_dataloader_custom_file(
            dir_data, size_img, ext_img, n_img_per_batch, n_worker,
            li_idx_sample_ratio)

    # Fixed 64x64 input resolution; regression objective.
    net = Network((64, 64))
    criterion = nn.MSELoss()

    # Move both model and loss module to the requested device.
    net = net.to(dev)
    criterion.to(dev)

    optimizer = optim.Adam(net.parameters())
    # Reduce LR when the monitored quantity stops decreasing.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  verbose=1,
                                  patience=8,
                                  epsilon=0.00001,
                                  min_lr=0.000001)

    return trainloader, testloader, net, criterion, optimizer, scheduler, li_idx_sample, li_fn_sample
Example #4
0
def initialize(is_gpu, dir_data, di_set_transform, ext_img, n_img_per_batch,
               n_worker):
    """Set up file-based dataloaders and a GAP-classifier training stack.

    Returns:
        (trainloader, testloader, net, criterion, optimizer, scheduler,
         li_class)
    """
    trainloader, testloader, li_class = make_dataloader_custom_file(
        dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker)

    # Global-average-pooling variant of the network.
    net = Net_gap()
    criterion = nn.CrossEntropyLoss()
    if is_gpu:
        net.cuda()
        criterion.cuda()

    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # Plateau-based LR decay on the minimized metric.
    scheduler = ReduceLROnPlateau(
        optimizer, 'min', verbose=1, patience=8, epsilon=0.00001,
        min_lr=0.000001)

    return trainloader, testloader, net, criterion, optimizer, scheduler, li_class
Example #5
0
def train():
    """Train the X-UMX (OpenUnmix_CrossNet) source-separation model with NNabla.

    Builds training/validation computation graphs, runs multi-GPU data-parallel
    training via a communicator, reduces the LR on validation-loss plateaus,
    applies early stopping, and saves the best parameters on rank 0.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    # Only rank 0 prints and creates the output directory.
    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    # Shard the iterators across processes for data-parallel training.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    # Scale weight decay by process count to keep the effective decay constant.
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    # Validation variables use batch size 1 and a fixed segment duration.
    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    # Combined objective: frequency-domain MSE + mcoef-weighted time-domain SDR.
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            # In multi-GPU mode, gradients are all-reduced during backward.
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            # Slide a dur-sized window over the full track and average the
            # per-window losses; stop when the next window would be short
            # or the track length is an exact multiple of dur.
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # Plateau scheduler decides the next LR from the validation loss.
        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
Example #6
0
def train_model(args):
    """Train a BERT-based contextual-bandit answer-selection model.

    Mixes an RL bandit loss with a binary cross-entropy supervised loss
    (weighted by args.gamma), periodically evaluates, and checkpoints the
    model whenever the evaluation reward improves.

    Args:
        args: parsed CLI namespace (model/optimizer/scheduler hyperparameters,
            data paths, logging and evaluation settings).

    Returns:
        The trained (or loaded) BERT_CB model.
    """
    print(args)
    print("generating config")
    config = Config(
        input_dim=args.input_dim,
        dropout=args.dropout,
        highway=args.highway,
        nn_layers=args.nn_layers,
    )
    # File names encode all hyperparameters so runs are self-describing.
    model_name = ".".join(
        (args.model_file, str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch),
         "learning_rate", str(args.lr) + "-" + str(args.lr_sch), "bsz",
         str(args.batch_size), "data", args.data_dir.split('/')[0],
         args.eval_data, "input_dim", str(config.input_dim), "max_num",
         str(args.max_num_of_ans), "reward", str(args.reward_type), "dropout",
         str(args.dropout) + "-" + str(args.clip_grad), "highway",
         str(args.highway), "nn-" + str(args.nn_layers), 'ans'))

    log_name = ".".join(
        ("log_bert/model", str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch), "lr", str(args.lr) + "-" + str(args.lr_sch),
         "bsz", str(args.batch_size), "data", args.data_dir.split('/')[0],
         args.eval_data, "input_dim", str(config.input_dim), "max_num",
         str(args.max_num_of_ans), "reward", str(args.reward_type), "dropout",
         str(args.dropout) + "-" + str(args.clip_grad), "highway",
         str(args.highway), "nn-" + str(args.nn_layers), 'ans'))

    print("initialising data loader and RL learner")
    data_loader = PickleReader(args.data_dir)
    # Hard-coded dataset sizes used only to derive the validation cadence.
    data = args.data_dir.split('/')[0]
    num_data = 0
    if data == "wiki_qa":
        num_data = 873
    elif data == "trec_qa":
        num_data = 1229
    else:
        # NOTE(review): aborts via always-false assert on unknown datasets;
        # raising ValueError would be clearer (assert is stripped under -O).
        assert (1 == 2)
    # init statistics
    reward_list = []
    loss_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    bandit = ContextualBandit(b=args.batch_size,
                              rl_baseline_method=args.rl_baseline_method,
                              sample_method=args.sampling_method)

    print("Loaded the Bandit")

    bert_cb = model2.BERT_CB(config)

    print("Loaded the model")

    bert_cb.cuda()
    vocab = "vocab"

    # Optionally resume from an existing full-model pickle and re-evaluate
    # it to seed best_eval_reward.
    if args.load_ext:
        model_name = args.model_file
        print("loading existing model%s" % model_name)
        bert_cb = torch.load(model_name,
                             map_location=lambda storage, loc: storage)
        bert_cb.cuda()
        model_save_name = model_name
        log_name = "/".join(("log_bert", model_name.split("/")[1]))
        print("finish loading and evaluate model %s" % model_name)
        # evaluate.ext_model_eval(extract_net, vocab, args, eval_data="test")
        best_eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                   args.eval_data)[0]
    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-10s %(message)s')
    # Loss and Optimizer
    optimizer_ans = torch.optim.Adam([
        param for param in bert_cb.parameters() if param.requires_grad == True
    ],
                                     lr=args.lr,
                                     betas=(args.beta, 0.999),
                                     weight_decay=1e-6)
    # lr_sch == 1: plateau scheduler stepped once per epoch on eval reward;
    # lr_sch == 2: cyclic LR stepped every batch.
    if args.lr_sch == 1:
        scheduler = ReduceLROnPlateau(optimizer_ans,
                                      'max',
                                      verbose=1,
                                      factor=0.9,
                                      patience=3,
                                      cooldown=3,
                                      min_lr=9e-5,
                                      epsilon=1e-6)
        if best_eval_reward:
            scheduler.step(best_eval_reward, 0)
            print("init_scheduler")
    elif args.lr_sch == 2:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer_ans,
            args.lr,
            args.lr_2,
            step_size_up=3 * int(num_data / args.train_batch),
            step_size_down=3 * int(num_data / args.train_batch),
            mode='exp_range',
            gamma=0.98,
            cycle_momentum=False)
    print("starting training")
    start_time = time.time()
    n_step = 100
    gamma = args.gamma
    #vocab = "vocab"
    # Evaluate more often on small datasets (every 1/5 epoch vs 1/7).
    if num_data < 2000:

        n_val = int(num_data / (5 * args.train_batch))
    else:
        n_val = int(num_data / (7 * args.train_batch))
    with torch.autograd.set_detect_anomaly(True):
        for epoch in tqdm(range(args.epochs_ext), desc="epoch:"):
            train_iter = data_loader.chunked_data_reader(
                "train", data_quota=args.train_example_quota)  #-1
            step_in_epoch = 0
            for dataset in train_iter:
                for step, contexts in tqdm(
                        enumerate(
                            BatchDataLoader(dataset,
                                            batch_size=args.train_batch,
                                            shuffle=True))):
                    # NOTE(review): broad try/except keeps training alive on
                    # per-batch failures but can mask real bugs.
                    try:
                        bert_cb.train()
                        step_in_epoch += 1
                        loss = 0.
                        reward = 0.
                        for context in contexts:

                            # q_a = torch.autograd.Variable(torch.from_numpy(context.features)).cuda()
                            pre_processed, a_len, sorted_id = model2.bert_preprocess(
                                context.answers)
                            q_a = torch.autograd.Variable(
                                pre_processed.type(torch.float))
                            a_len = torch.autograd.Variable(a_len)

                            outputs = bert_cb(q_a, a_len)
                            # Re-align labels with the preprocessing sort order.
                            context.labels = np.array(
                                context.labels)[sorted_id]

                            # Occasionally (~1%) print probabilities for debugging.
                            if args.prt_inf and np.random.randint(0, 100) == 0:
                                prt = True
                            else:
                                prt = False

                            loss_t, reward_t = bandit.train(
                                outputs,
                                context,
                                max_num_of_ans=args.max_num_of_ans,
                                reward_type=args.reward_type,
                                prt=prt)
                            #print(str(loss_t)+' '+str(len(a_len)))

                            #    loss_t = loss_t.view(-1)
                            # Binarize gold labels for the supervised BCE term.
                            true_labels = np.zeros(len(context.labels))
                            gold_labels = np.array(context.labels)
                            true_labels[gold_labels > 0] = 1.0
                            # ml_loss = F.binary_cross_entropy(outputs.view(-1),torch.tensor(true_labels).type(torch.float).cuda())
                            ml_loss = F.binary_cross_entropy(
                                outputs.view(-1),
                                torch.tensor(true_labels).type(
                                    torch.float).cuda())

                            # Mix RL and supervised losses; gradients accumulate
                            # across the contexts in this batch.
                            loss_e = ((gamma * loss_t) +
                                      ((1 - gamma) * ml_loss))
                            loss_e.backward()
                            loss += loss_e.item()
                            reward += reward_t
                        loss = loss / args.train_batch
                        reward = reward / args.train_batch
                        if prt:
                            print('Probabilities: ',
                                  outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)

                        reward_list.append(reward)
                        loss_list.append(loss)
                        #if isinstance(loss, Variable):
                        #    loss.backward()

                        # Step/zero every batch (the %1 guard is a no-op).
                        if step % 1 == 0:
                            if args.clip_grad:
                                torch.nn.utils.clip_grad_norm_(
                                    bert_cb.parameters(),
                                    args.clip_grad)  # gradient clipping
                            optimizer_ans.step()
                            optimizer_ans.zero_grad()
                        if args.lr_sch == 2:
                            scheduler.step()
                        logging.info('Epoch %d Step %d Reward %.4f Loss %.4f' %
                                     (epoch, step_in_epoch, reward, loss))
                    except Exception as e:
                        print(e)
                        #print(loss)
                        #print(loss_e)
                        traceback.print_exc()

                    # Periodic rolling-average logging, then reset the stats.
                    if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                        logging.info('Epoch ' + str(epoch) + ' Step ' +
                                     str(step_in_epoch) + ' reward: ' +
                                     str(np.mean(reward_list)) + ' loss: ' +
                                     str(np.mean(loss_list)))
                        reward_list = []
                        loss_list = []

                    # Periodic evaluation; checkpoint on improvement.
                    if (step_in_epoch) % n_val == 0 and step_in_epoch != 0:
                        print("doing evaluation")
                        bert_cb.eval()
                        eval_reward = evaluate.ext_model_eval(
                            bert_cb, vocab, args, args.eval_data)

                        if eval_reward[0] > best_eval_reward:
                            best_eval_reward = eval_reward[0]
                            print(
                                "saving model %s with eval_reward:" %
                                model_save_name, eval_reward)
                            logging.debug("saving model" +
                                          str(model_save_name) +
                                          "with eval_reward:" +
                                          str(eval_reward))
                            torch.save(bert_cb, model_name)
                        print('epoch ' + str(epoch) +
                              ' reward in validation: ' + str(eval_reward))
                        logging.debug('epoch ' + str(epoch) +
                                      ' reward in validation: ' +
                                      str(eval_reward))
                        logging.debug('time elapsed:' +
                                      str(time.time() - start_time))
            # End-of-epoch evaluation drives the plateau scheduler.
            if args.lr_sch == 1:
                bert_cb.eval()
                eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                      args.eval_data)
                scheduler.step(eval_reward[0], epoch)
    return bert_cb
Example #7
0
def main():
    """Entry point: train a speech-commands classifier.

    Parses CLI args, builds datasets/loaders and the model, selects an
    optimizer and LR schedule, optionally resumes from a checkpoint and/or
    fine-tunes the classifier head, then runs the train/validate loop with
    per-epoch checkpointing and a CSV summary.
    """
    args = parser.parse_args()

    if args.output:
        output_base = args.output
    else:
        output_base = './output'
    # Unique experiment directory: timestamp-model-pool-fold.
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        args.gp,
        'f'+str(args.fold)])
    output_dir = get_outdir(output_base, 'train', exp_name)

    train_input_root = os.path.join(args.data)
    batch_size = args.batch_size
    num_epochs = args.epochs
    wav_size = (16000,)
    num_classes = len(dataset.get_labels())

    torch.manual_seed(args.seed)

    model = model_factory.create_model(
        args.model,
        in_chs=1,
        pretrained=args.pretrained,
        num_classes=num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)
    #model.reset_classifier(num_classes=num_classes)

    dataset_train = dataset.CommandsDataset(
        root=train_input_root,
        mode='train',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.workers
    )

    dataset_eval = dataset.CommandsDataset(
        root=train_input_root,
        mode='validate',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=args.workers
    )

    # Same loss object serves train and eval (both aliases share state).
    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = validate_loss_fn.cuda()

    # Optimizer is chosen by name; unknown names abort.
    opt_params = list(model.parameters())
    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            opt_params, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'nadam':
        optimizer = nadam.Nadam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            opt_params, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        assert False and "Invalid optimizer"
    del opt_params

    # Without explicit decay epochs, fall back to plateau-based LR decay.
    if not args.decay_epochs:
        print('No decay epoch set, using plateau scheduler.')
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=10)
    else:
        lr_scheduler = None

    # optionally resume from a checkpoint
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                if 'args' in checkpoint:
                    print(checkpoint['args'])
                # Strip the DataParallel 'module.' prefix if present.
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if k.startswith('module'):
                        name = k[7:] # remove `module.`
                    else:
                        name = k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if 'loss' in checkpoint:
                    train_loss_fn.load_state_dict(checkpoint['loss'])
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
            else:
                # Bare state_dict checkpoint (no metadata wrapper).
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit(1)

    saver = CheckpointSaver(checkpoint_dir=output_dir)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # Optional fine-tune of only the final classifier weights for specified number of epochs (or part of)
    if not args.resume and args.ft_epochs > 0.:
        if isinstance(model, torch.nn.DataParallel):
            classifier_params = model.module.get_classifier().parameters()
        else:
            classifier_params = model.get_classifier().parameters()
        if args.opt.lower() == 'adam':
            finetune_optimizer = optim.Adam(
                classifier_params,
                lr=args.ft_lr, weight_decay=args.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                classifier_params,
                lr=args.ft_lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

        # ft_epochs may be fractional: the last pass runs only a batch_limit
        # prefix of the training loader.
        finetune_epochs_int = int(np.ceil(args.ft_epochs))
        finetune_final_batches = int(np.ceil((1 - (finetune_epochs_int - args.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(0, finetune_epochs_int):
            if fepoch == finetune_epochs_int - 1 and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, train_loss_fn, args,
                output_dir=output_dir, batch_limit=batch_limit)

    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            # Step-decay path (mutually exclusive with the plateau scheduler).
            if args.decay_epochs:
                adjust_learning_rate(
                    optimizer, epoch, initial_lr=args.lr,
                    decay_rate=args.decay_rate, decay_epochs=args.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                saver=saver, output_dir=output_dir)

            # save a recovery in case validation blows up
            saver.save_recovery({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': train_loss_fn.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                batch_idx=0)

            step = epoch * len(loader_train)
            eval_metrics = validate(
                step, model, loader_eval, validate_loss_fn, args,
                output_dir=output_dir)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            # Append train+eval metrics for this epoch to summary.csv.
            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # save proper checkpoint with eval metric
            best_loss = saver.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                metric=eval_metrics['eval_loss'])

    except KeyboardInterrupt:
        pass
    # best_loss is (epoch, loss) per saver.save_checkpoint's return.
    print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
Example #8
0
def main():
    """Entry point: train a CnnOcrModel with CTC loss on line-image data.

    Builds train/validation dataloaders, optionally resumes from a saved
    snapshot, then runs an iteration-counted training loop that periodically
    validates, snapshots the model, lowers the LR on WER plateau, and exits
    early when the scheduler reports it is finished.
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    # Two checkpoint files: a rolling "current" snapshot and a best-so-far copy.
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    # Per-line-image preprocessing: resize to fixed height, invert polarity,
    # then convert to tensor.
    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # NOTE(review): cudnn autotuning is explicitly DISABLED here despite the
    # original "faster code" comment — presumably because line widths vary
    # per batch (benchmark mode hurts with variable input sizes); confirm.
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    # Batches are grouped by width (GroupedSampler + SortByWidthCollater) so
    # padding waste inside a batch is minimized.
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    # Either resume the full model from a snapshot or build a fresh one.
    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')  # lower WER is better

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)

    # NOTE(review): this ReduceLROnPlateau is used as if step() returns a
    # truthy value when the LR was lowered and exposes .finished/.factor/
    # .min_lr — torch's built-in scheduler does none of that, so this is
    # presumably a project-local scheduler class; confirm the import.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    # Histories kept for plotting/diagnostics only.
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            # Normalize the summed CTC loss to a per-sample value for logging.
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Do something with loss, running average, plot to some backend server, etc

            # Periodic validation + snapshot every N iterations.
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)
                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                # Snapshot everything needed to resume or deploy the model.
                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                # After an LR drop, rewind to the best weights so the lower LR
                # continues from the best point rather than a plateaued one.
                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
Example #9
0
# get some random training images
dataiter = iter(trainloader)
# Fix: DataLoader iterators follow the Python-3 iterator protocol; the
# legacy `.next()` method was removed (PyTorch >= 1.13) — use builtin next().
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
#'''

net = Net().cuda()
#t1 = net.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=1) # set up scheduler

# Running counters for the training loop below.
n_image_total = 0
running_loss = 0.0
is_lr_just_decayed = False
shall_stop = False
for epoch in range(n_epoch):  # loop over the dataset multiple times
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # wrap them in Variable (no-op on modern PyTorch; kept for
        # compatibility with the original script)
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())

        # zero the parameter gradients
        optimizer.zero_grad()
Example #10
0
def train():
    """Train an X-UMX/UMX source-separation model with NNabla.

    Sets up (optionally multi-GPU) context and data iterators for MUSDB18,
    scales LR/weight-decay by the effective batch size, then runs an
    epoch loop of training + chunked validation with ReduceLROnPlateau LR
    scheduling and early stopping on validation loss.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    # Only rank 0 creates the output directory (avoids a multi-process race).
    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB18.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(
        train_source,
        args.batch_size,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    # Validation uses batch size 1 (full tracks are chunked manually below).
    valid_iter = data_iterator(
        valid_source,
        1,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    # Shard the iterators across processes for multi-GPU training.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    # Change max_iter, learning_rate and weight_decay according no. of gpu devices for multi-gpu training.
    default_batch_size = 16
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * train_scale_factor
    # Linear LR scaling with effective batch size.
    args.lr = args.lr * train_scale_factor

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    # Get X-UMX/UMX computation graph and variables as namedtuple
    model = get_model(args, scaler_mean, scaler_std, max_bin=max_bin)

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    # NOTE(review): this is an LR-value scheduler (update_lr returns the new
    # LR), not torch's optimizer-wrapping class — presumably project-local.
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # AverageMeter for mean loss calculation over the epoch
    losses = utils.AverageMeter()

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            model.mixture_audio.d, model.target_audio.d = train_iter.next()
            solver.zero_grad()
            model.loss.forward(clear_no_need_grad=True)
            # All-reduce gradients across processes when training multi-GPU.
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                model.loss.backward(clear_buffer=True,
                                    communicator_callbacks=all_reduce_callback)
            else:
                model.loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(model.loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        # Full tracks are evaluated in fixed-duration chunks and the chunk
        # losses averaged, to bound memory usage.
        losses.reset()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                model.vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                model.vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                model.vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += model.vloss.data
                # Stop once the next chunk would be short or the track is
                # exactly consumed.
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            losses.update(loss_tmp.data.copy(), 1)
        validation_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        # Adjust LR on validation plateau and check early-stopping criterion.
        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            # Persist weights whenever this epoch matches the best loss seen.
            if validation_loss == es.best:
                best_epoch = epoch
                # save best model
                if args.umx_train:
                    nn.save_parameters(os.path.join(args.output,
                                                    'best_umx.h5'))
                else:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_xumx.h5'))

        if args.umx_train:
            # Early stopping for UMX after `args.patience` (140) number of epochs
            if stop:
                print("Apply Early Stopping")
                break

# Wrap the train/val dataframes in controllers + DICOM preprocessing
# (augmentation enabled for both splits, as in the original setup).
dataset_loaders = {
    split: DataLoader(Controller(frame), DICOMPreprocessor(augment=True))
    for split, frame in (('train', train_df), ('val', val_df))
}


# Number of samples per split, as reported by each loader.
dataset_sizes = {
    split: loader.shape() for split, loader in dataset_loaders.items()
}


# LR scheduler agent: reduce the learning rate when the monitored
# quantity (a minimized metric) stops improving.
RLRP_agent = ReduceLROnPlateau('min')


num_epochs = 5


# Run training and keep the best-performing model.
best_model = train_model(args, model, criterion, dataset_loaders,
                         dataset_sizes, RLRP_agent, num_epochs)

print(best_model)
Example #12
0
def main():
    """Entry point: train a CnnOcrModel with CTC loss, with rich augmentation.

    Extended variant of the basic OCR trainer: builds a configurable image
    augmentation pipeline, supports multiple/union datasets, an optional
    held-out test set with hypothesis write-out, snapshot resume with
    alphabet override, periodic validation, LR reduction on WER plateau,
    and scheduler-driven early exit.
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    # Rolling "current" snapshot plus a best-so-far copy.
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = []

    #if args.num_in_channels == 3:
    #    line_img_transforms.append(imagetransforms.ConvertColor())

    # Always convert color for the augmentations to work (for now)
    # Then alter convert back to grayscale if needed
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only)
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:

        # Randomly rotate image from -2 degrees to +2 degrees
        line_img_transforms.append(
            imagetransforms.Randomize(0.3, imagetransforms.RotateRandom(-2,
                                                                        2)))

        # Choose one of methods to blur/pixel-ify image  (or don't and choose identity)
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        # Random contrast normalization (imgaug), applied half the time.
        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale:  convert to grayscale and add back into color-image with random alpha
        #   Emboss:  Emboss image with random strength
        #   Invert:  Invert colors of image per-channel
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

        # Randomly try to crop close to top/bottom and left/right of lines
        # For now we are just guessing (up to 5% of ends and up to 10% of tops/bottoms chopped off)

        if args.tight_crop:
            # To make sure padding is reasonably consistent, we first rsize image to target line height
            # Then add padding to this version of image
            # Below it will get resized again to target line height
            line_img_transforms.append(
                imagetransforms.Randomize(
                    0.9,
                    imagetransforms.Compose([
                        imagetransforms.Scale(new_h=args.line_height),
                        imagetransforms.PadRandom(pxl_max_horizontal=30,
                                                  pxl_max_vertical=10)
                    ])))

        else:
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropHorizontal(.05)))
            line_img_transforms.append(
                imagetransforms.Randomize(0.2,
                                          imagetransforms.CropVertical(.1)))

        #line_img_transforms.append(imagetransforms.Randomize(0.2,
        #                                                     imagetransforms.PickOne([imagetransforms.MorphErode(3), imagetransforms.MorphDilate(3)])
        #                                                     ))

    # Make sure to do resize after degrade step above
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only do for grayscale
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    # Optional synthetic horizontal stripe noise.
    if args.stripe:
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0,
                                                strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())

    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # NOTE(review): cudnn autotuning is explicitly DISABLED despite the
    # original "faster code" comment — presumably due to variable-width
    # line images; confirm.
    torch.backends.cudnn.benchmark = False

    # Single datadir -> plain dataset; multiple -> union dataset.
    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train",
                                   line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    # Optional held-out test set; test transforms skip all augmentation.
    if args.test_datadir is not None:
        if args.test_outdir is None:
            print(
                "Error, must specify both --test-datadir and --test-outdir together"
            )
            sys.exit(1)

        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)

        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        # Force all datasets (and union members) to use the saved alphabet so
        # label indices line up with the pre-trained output layer.
        if len(args.datadir) == 1:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
        else:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet

    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Setting dataloader after we have a chnae to (maybe!) over-ride the dataset alphabet from a pre-trained model
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    # Optionally cap validation to a fixed random subset for speed.
    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset, rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')  # lower WER is better

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)

    # NOTE(review): used as if step() returns truthy when the LR was lowered
    # and exposes .finished/.factor/.min_lr — not torch's built-in class;
    # presumably a project-local scheduler. Confirm the import.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    # Histories kept for plotting/diagnostics only.
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    # Test-set write-out is deferred until CER becomes non-random (< 0.5).
    do_test_write = False
    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            # Normalize the summed CTC loss to a per-sample value for logging.
            loss = loss / args.batch_size

            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Only turn on test-on-testset when cer is starting to get non-random
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                if val_cer < 0.5:
                    do_test_write = True

                # Write out-of-domain (test) and in-domain (validation)
                # hypotheses plus a metadata line for this checkpoint.
                if args.test_datadir is not None and (
                        iteration % snapshot_every_n_iterations
                        == 0) and do_test_write:
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir,
                        "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True

                    # for bookeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                # Snapshot everything needed to resume or deploy the model.
                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'rtl': args.rtl,
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                # After an LR drop, rewind to the best weights so the lower
                # LR continues from the best point seen so far.
                if lowered_lr:
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
Example #13
0
def train(args, model, processor):
    """Train a CRF-based NER model, checkpointing the best epoch by eval F1.

    Per epoch: one pass over the training set (logging per-batch loss and
    training precision), then evaluation on the dev set; the LR scheduler is
    stepped on eval F1 and the model state dict is saved to
    ``args.output_dir / 'best-model.bin'`` whenever eval F1 improves.

    Args:
        args: run configuration (batch_size, seed, label2id/id2label, markup,
            learning_rate, epochs, device, grad_norm, arch, output_dir, ...).
        model: model exposing ``forward_loss(...)`` and a ``crf`` attribute.
        processor: data processor providing ``vocab`` for the loader.
    """
    # NOTE(review): hard-coded local vocab path — confirm it exists where
    # this script is run.
    tokenizer = BertTokenizer.from_pretrained(
        './BERT_model/bert_pretrain/vocab.txt')

    train_dataset = load_and_cache_examples(args, processor, data_type='train')
    train_loader = DatasetLoader(data=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 seed=args.seed,
                                 sort=True,
                                 vocab=processor.vocab,
                                 label2id=args.label2id,
                                 tokenizer=tokenizer)
    # train_loader = DatasetLoader(data=train_dataset, batch_size=args.batch_size,
    #                              shuffle=False, seed=args.seed, sort=True,
    #                              vocab=processor.vocab, label2id=args.label2id)
    # Optimize only trainable parameters; halve the LR (factor=0.5) when
    # eval F1 (mode='max') plateaus for `patience` epochs.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='max',
                                  factor=0.5,
                                  patience=3,
                                  verbose=1,
                                  epsilon=1e-4,
                                  cooldown=0,
                                  min_lr=0,
                                  eps=1e-8)

    train_metric = SeqEntityScore(args.id2label, markup=args.markup)
    best_f1 = 0
    for epoch in range(1, 1 + args.epochs):
        # Epoch wall-clock timer (sic: "strat" = "start").
        strat_epoch_time = time.time()
        logger.info(f"Epoch {epoch}/{args.epochs}")
        #pbar = ProgressBar(n_total=len(train_loader), desc='Training')  # progress-bar style
        train_loss = AverageMeter()
        model.train()
        assert model.training
        for step, batch in enumerate(train_loader):
            strat_batch_time = time.time()
            input_ids, input_mask, input_tags, input_lens = batch
            input_ids = input_ids.to(args.device)
            input_mask = input_mask.to(args.device)
            input_tags = input_tags.to(args.device)

            features, loss = model.forward_loss(input_ids, input_mask,
                                                input_lens, input_tags)

            # Standard step: backprop, clip gradient norm, update, reset grads.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            optimizer.step()
            optimizer.zero_grad()

            # pbar(step=step, info={'loss': loss.item()})
            train_loss.update(loss.item(), n=1)

            # Decode CRF predictions and truncate gold tags to each sequence's
            # true length so training precision ignores padding tokens.
            tags, _ = model.crf._obtain_labels(features, args.id2label,
                                               input_lens)
            input_tags = input_tags.cpu().numpy()
            target = [
                input_[:len_] for input_, len_ in zip(input_tags, input_lens)
            ]

            pre_train = train_metric.compute_train_pre(label_paths=target,
                                                       pred_paths=tags)
            logger.info(
                f'time: {time.time() - strat_batch_time:.1f}  train_loss: {loss.item():.4f}  train_pre: {pre_train:.4f}'
            )
        print(" ")
        logger.info(f'train_total_time: {time.time() - strat_epoch_time}')

        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()  # free cached GPU memory

        strat_eval_time = time.time()
        eval_f1 = evaluate(args, model, processor)

        show_info = f'eval_time: {time.time() - strat_eval_time:.1f}   train_avg_loss: {train_loss.avg:.4f}  eval_f1: {eval_f1:.4f} '
        logger.info(show_info)
        scheduler.epoch_step(eval_f1, epoch)

        # Checkpoint only when the dev F1 improves on the best seen so far.
        if eval_f1 > best_f1:
            # Epoch 1: eval_f1 improved from 0 to 0.4023105674481821
            logger.info(
                f"\nEpoch {epoch}: eval_f1 improved from {best_f1} to {eval_f1}"
            )

            best_f1 = eval_f1

            model_stat_dict = model.state_dict()
            state = {
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model_stat_dict
            }
            model_path = args.output_dir / 'best-model.bin'
            torch.save(state, str(model_path))