Example #1
def train_iteration(model, sync_barrier, is_master_proc, device_qty, loss_obj,
                    train_params, max_train_qty, valid_run,
                    valid_qrel_filename, optimizer, scheduler, dataset,
                    train_pairs, qrels, validation_timer, valid_run_dir,
                    valid_scores_holder, save_last_snapshot_every_k_batch,
                    model_out_dir):

    clean_memory(train_params.device_name)

    model.train()
    total_loss = 0.
    total_prev_qty = total_qty = 0.  # Total number of records processed;
    # it can differ from the total number of training pairs

    batch_size = train_params.batch_size

    optimizer.zero_grad()

    if train_params.print_grads:
        print('Gradient sums before training')
        for k, v in model.named_parameters():
            print(
                k, 'None' if v.grad is None else torch.sum(
                    torch.norm(v.grad, dim=-1, p=2)))

    lr_desc = get_lr_desc(optimizer)

    batch_id = 0
    snap_id = 0

    if is_master_proc:

        utils.sync_out_streams()

        pbar = tqdm(desc='training',
                    total=max_train_qty,
                    ncols=80,
                    leave=False)
    else:
        pbar = None

    for record in data.iter_train_pairs(model, train_params.device_name,
                                        dataset, train_pairs,
                                        train_params.shuffle_train, qrels,
                                        train_params.backprop_batch_size,
                                        train_params.max_query_len,
                                        train_params.max_doc_len):
        scores = model(record['query_tok'], record['query_mask'],
                       record['doc_tok'], record['doc_mask'])
        count = len(record['query_id']) // 2
        scores = scores.reshape(count, 2)
        loss = loss_obj.compute(scores)
        loss.backward()
        total_qty += count

        if train_params.print_grads:
            print(f'Records processed: {total_qty}. Gradient sums:')
            for k, v in model.named_parameters():
                print(
                    k, 'None' if v.grad is None else torch.sum(
                        torch.norm(v.grad, dim=-1, p=2)))

        total_loss += loss.item()

        if total_qty - total_prev_qty >= batch_size:
            if is_master_proc:
                validation_timer.increment(total_qty - total_prev_qty)

            #print(total, 'optimizer step!')
            optimizer.step()
            optimizer.zero_grad()
            total_prev_qty = total_qty

            # The scheduler must make a step for every batch, *after* the optimizer update!
            if scheduler is not None:
                scheduler.step()
                lr_desc = get_lr_desc(optimizer)

            # This must be done in every process, not only in the master process
            if device_qty > 1:
                if batch_id % train_params.batch_sync_qty == 0:
                    try:
                        sync_barrier.wait(BARRIER_WAIT_MODEL_AVERAGE_TIMEOUT)
                    except BrokenBarrierError:
                        raise Exception(
                            'A model parameter synchronization timeout!')

                    avg_model_params(model)

            batch_id += 1

            # batch_id was just incremented, so batch_id == 0 never triggers a snapshot save
            if save_last_snapshot_every_k_batch is not None and batch_id % save_last_snapshot_every_k_batch == 0:
                if is_master_proc:
                    os.makedirs(model_out_dir, exist_ok=True)
                    out_tmp = os.path.join(model_out_dir,
                                           f'model.last.{snap_id}')
                    torch.save(model, out_tmp)
                    snap_id += 1

        if pbar is not None:
            pbar.update(count)
            pbar.refresh()
            utils.sync_out_streams()
            pbar.set_description('%s train loss %.5f' %
                                 (lr_desc, total_loss / float(total_qty)))

        while (is_master_proc and validation_timer.is_time()
               and valid_run_dir is not None):
            model.eval()
            os.makedirs(valid_run_dir, exist_ok=True)
            run_file_name = os.path.join(
                valid_run_dir,
                f'batch_{validation_timer.last_checkpoint()}.run')
            pbar.refresh()
            utils.sync_out_streams()
            score = validate(model,
                             train_params,
                             dataset,
                             valid_run,
                             qrelf=valid_qrel_filename,
                             run_filename=run_file_name)

            pbar.refresh()
            utils.sync_out_streams()
            print(
                f'\n# of batches={validation_timer.total_steps} score={score:.4g}'
            )
            valid_scores_holder[
                f'batch_{validation_timer.last_checkpoint()}'] = score
            utils.save_json(os.path.join(valid_run_dir, "scores.log"),
                            valid_scores_holder)
            model.train()

        if total_qty >= max_train_qty:
            break

    # Final model parameter averaging at the end of the iteration.

    if device_qty > 1:
        try:
            sync_barrier.wait(BARRIER_WAIT_MODEL_AVERAGE_TIMEOUT)
        except BrokenBarrierError:
            raise Exception('A model parameter synchronization timeout!')

        avg_model_params(model)

    if pbar is not None:
        pbar.close()
        utils.sync_out_streams()

    return total_loss / float(total_qty)
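
The inner loop above implements gradient accumulation: backward() runs on small backprop batches, and the optimizer only steps once roughly every batch_size records have been processed. Below is a minimal, self-contained sketch of the same pattern; the model, data, and names (accumulate_and_step, chunks) are made up for illustration and are not part of the example above.

# Sketch of the gradient-accumulation pattern used in train_iteration.
import torch

def accumulate_and_step(model, chunks, optimizer, batch_size):
    processed = prev_processed = 0
    optimizer.zero_grad()
    for x, y in chunks:                      # each chunk is much smaller than batch_size
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()                      # gradients accumulate across chunks
        processed += len(x)
        if processed - prev_processed >= batch_size:
            optimizer.step()                 # one update per "logical" batch
            optimizer.zero_grad()
            prev_processed = processed

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
chunks = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(16)]
accumulate_and_step(model, chunks, optimizer, batch_size=32)
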
Example #2
def do_train(sync_barrier, device_qty, master_port, rank, is_master_proc,
             dataset, qrels, qrel_file_name, train_pairs, valid_run,
             valid_run_dir, valid_checkpoints, model_out_dir, model, loss_obj,
             train_params):
    if device_qty > 1:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(master_port)
        dist.init_process_group(utils.PYTORCH_DISTR_BACKEND,
                                rank=rank,
                                world_size=device_qty)

    device_name = train_params.device_name

    if is_master_proc:
        print('Training parameters:')
        print(train_params)
        print('Loss function:', loss_obj.name())

    print('Device name:', device_name)

    model.to(device_name)

    lr = train_params.init_lr
    bert_lr = train_params.init_bert_lr
    epoch_lr_decay = train_params.epoch_lr_decay
    weight_decay = train_params.weight_decay
    momentum = train_params.momentum

    top_valid_score = None

    train_stat = {}

    validation_timer = ValidationTimer(valid_checkpoints)
    valid_scores_holder = dict()
    for epoch in range(train_params.epoch_qty):

        params = [(k, v) for k, v in model.named_parameters()
                  if v.requires_grad]
        non_bert_params = {
            'params': [v for k, v in params if not k.startswith('bert.')]
        }
        bert_params = {
            'params': [v for k, v in params if k.startswith('bert.')],
            'lr': bert_lr
        }

        if train_params.optim == OPT_ADAMW:
            optimizer = torch.optim.AdamW([non_bert_params, bert_params],
                                          lr=lr,
                                          weight_decay=weight_decay)
        elif train_params.optim == OPT_SGD:
            optimizer = torch.optim.SGD([non_bert_params, bert_params],
                                        lr=lr,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        else:
            raise Exception('Unsupported optimizer: ' + train_params.optim)

        bpte = train_params.batches_per_train_epoch
        max_train_qty = data.train_item_qty(
            train_pairs) if bpte <= 0 else bpte * train_params.batch_size

        lr_steps = int(math.ceil(max_train_qty / train_params.batch_size))
        scheduler = None
        if train_params.warmup_pct:
            if is_master_proc:
                print('Using a scheduler with a warm-up fraction of %f' %
                      train_params.warmup_pct)
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                total_steps=lr_steps,
                max_lr=[lr, bert_lr],
                anneal_strategy='linear',
                pct_start=train_params.warmup_pct)
        if is_master_proc:
            print('Optimizer', optimizer)

        start_time = time.time()

        loss = train_iteration(model=model,
                               sync_barrier=sync_barrier,
                               is_master_proc=is_master_proc,
                               device_qty=device_qty,
                               loss_obj=loss_obj,
                               train_params=train_params,
                               max_train_qty=max_train_qty,
                               valid_run=valid_run,
                               valid_qrel_filename=qrel_file_name,
                               optimizer=optimizer,
                               scheduler=scheduler,
                               dataset=dataset,
                               train_pairs=train_pairs,
                               qrels=qrels,
                               validation_timer=validation_timer,
                               valid_run_dir=valid_run_dir,
                               valid_scores_holder=valid_scores_holder,
                               save_last_snapshot_every_k_batch=train_params.save_last_snapshot_every_k_batch,
                               model_out_dir=model_out_dir)

        end_time = time.time()

        if is_master_proc:

            if train_params.save_epoch_snapshots:
                print('Saving the model epoch snapshot')
                torch.save(model, os.path.join(model_out_dir,
                                               f'model.{epoch}'))

            os.makedirs(model_out_dir, exist_ok=True)

            print(
                f'train epoch={epoch} loss={loss:.3g} lr={lr:g} bert_lr={bert_lr:g}'
            )

            utils.sync_out_streams()

            valid_score = validate(model,
                                   train_params,
                                   dataset,
                                   valid_run,
                                   qrelf=qrel_file_name,
                                   run_filename=os.path.join(
                                       model_out_dir, f'{epoch}.run'))

            utils.sync_out_streams()

            print(f'validation epoch={epoch} score={valid_score:.4g}')

            train_stat[epoch] = {
                'loss': loss,
                'score': valid_score,
                'lr': lr,
                'bert_lr': bert_lr,
                'train_time': end_time - start_time
            }

            utils.save_json(os.path.join(model_out_dir, 'train_stat.json'),
                            train_stat)

            if top_valid_score is None or valid_score > top_valid_score:
                top_valid_score = valid_score
                print('new top validation score, saving the whole model')
                torch.save(model, os.path.join(model_out_dir, 'model.best'))

        # We must sync here, or else non-master processes would start training and
        # time out on the model-averaging barrier. The wait time here can be much
        # longer. Ideally, validation should be split across GPUs, but it is
        # usually quick enough that this works as a (semi-)temporary fix.
        if device_qty > 1:
            try:
                sync_barrier.wait(BARRIER_WAIT_VALIDATION_TIMEOUT)
            except BrokenBarrierError:
                raise Exception('A model parameter synchronization timeout!')

        lr *= epoch_lr_decay
        bert_lr *= epoch_lr_decay
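
do_train builds the optimizer from two parameter groups, so parameters whose names start with 'bert.' get their own learning rate, and OneCycleLR receives one max_lr per group. Below is a minimal sketch of just that setup; ToyRanker, the learning rates, and the step counts are illustrative stand-ins, not taken from the example above.

# Sketch of the two-parameter-group optimizer and OneCycleLR setup from do_train.
import torch

class ToyRanker(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = torch.nn.Linear(8, 8)    # stands in for the BERT encoder
        self.head = torch.nn.Linear(8, 1)

model = ToyRanker()
lr, bert_lr = 1e-3, 2e-5
params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
param_groups = [
    {'params': [v for k, v in params if not k.startswith('bert.')]},             # uses lr
    {'params': [v for k, v in params if k.startswith('bert.')], 'lr': bert_lr},  # uses bert_lr
]
optimizer = torch.optim.AdamW(param_groups, lr=lr, weight_decay=1e-2)
# One max_lr entry per parameter group, as in do_train above.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, total_steps=100, max_lr=[lr, bert_lr],
    anneal_strategy='linear', pct_start=0.1)
for _ in range(10):
    optimizer.step()      # a real loop would compute a loss and call backward() first
    scheduler.step()
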
Example #3
def do_train(sync_barrier, device_qty, master_port, rank, is_master_proc,
             dataset, qrels, qrel_file_name, train_pairs, valid_run,
             model_out_dir, model, loss_obj, train_params):
    if device_qty > 1:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(master_port)
        dist.init_process_group(utils.PYTORCH_DISTR_BACKEND,
                                rank=rank,
                                world_size=device_qty)

    device_name = train_params.device_name

    if is_master_proc:
        print('Training parameters:')
        print(train_params)
        print('Loss function:', loss_obj.name())

    print('Device name:', device_name)

    model.to(device_name)

    lr = train_params.init_lr
    bert_lr = train_params.init_bert_lr
    epoch_lr_decay = train_params.epoch_lr_decay
    weight_decay = train_params.weight_decay
    momentum = train_params.momentum

    top_valid_score = None

    train_stat = {}

    for epoch in range(train_params.epoch_qty):

        params = [(k, v) for k, v in model.named_parameters()
                  if v.requires_grad]
        non_bert_params = {
            'params': [v for k, v in params if not k.startswith('bert.')]
        }
        bert_params = {
            'params': [v for k, v in params if k.startswith('bert.')],
            'lr': bert_lr
        }

        if train_params.optim == OPT_ADAMW:
            optimizer = torch.optim.AdamW([non_bert_params, bert_params],
                                          lr=lr,
                                          weight_decay=weight_decay)
        elif train_params.optim == OPT_SGD:
            optimizer = torch.optim.SGD([non_bert_params, bert_params],
                                        lr=lr,
                                        weight_decay=weight_decay,
                                        momentum=momentum)
        else:
            raise Exception('Unsupported optimizer: ' + train_params.optim)

        bpte = train_params.batches_per_train_epoch
        max_train_qty = data.train_item_qty(
            train_pairs) if bpte <= 0 else bpte * train_params.batch_size

        lr_steps = int(math.ceil(max_train_qty / train_params.batch_size))
        scheduler = None
        if train_params.warmup_pct:
            if is_master_proc:
                print('Using a scheduler with a warm-up fraction of %f' %
                      train_params.warmup_pct)
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                total_steps=lr_steps,
                max_lr=[lr, bert_lr],
                anneal_strategy='linear',
                pct_start=train_params.warmup_pct)
        if is_master_proc:
            print('Optimizer', optimizer)

        start_time = time.time()

        loss = train_iteration(model=model,
                               sync_barrier=sync_barrier,
                               is_master_proc=is_master_proc,
                               device_qty=device_qty,
                               loss_obj=loss_obj,
                               train_params=train_params,
                               max_train_qty=max_train_qty,
                               optimizer=optimizer,
                               scheduler=scheduler,
                               dataset=dataset,
                               train_pairs=train_pairs,
                               qrels=qrels,
                               save_last_snapshot_every_k_batch=train_params.save_last_snapshot_every_k_batch,
                               model_out_dir=model_out_dir)

        end_time = time.time()

        if is_master_proc:

            if train_params.save_epoch_snapshots:
                print('Saving the model epoch snapshot')
                torch.save(model, os.path.join(model_out_dir,
                                               f'model.{epoch}'))

            os.makedirs(model_out_dir, exist_ok=True)

            print(
                f'train epoch={epoch} loss={loss:.3g} lr={lr:g} bert_lr={bert_lr:g}'
            )

            utils.sync_out_streams()

            valid_score = validate(model, train_params, dataset, valid_run,
                                   qrel_file_name, epoch, model_out_dir)

            utils.sync_out_streams()

            print(f'validation epoch={epoch} score={valid_score:.4g}')

            train_stat[epoch] = {
                'loss': loss,
                'score': valid_score,
                'lr': lr,
                'bert_lr': bert_lr,
                'train_time': end_time - start_time
            }

            utils.save_json(os.path.join(model_out_dir, 'train_stat.json'),
                            train_stat)

            if top_valid_score is None or valid_score > top_valid_score:
                top_valid_score = valid_score
                print('new top validation score, saving the whole model')
                torch.save(model, os.path.join(model_out_dir, 'model.best'))

        lr *= epoch_lr_decay
        bert_lr *= epoch_lr_decay
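
Both training loops call avg_model_params() after a barrier to keep the per-GPU model copies in sync, but its body is not shown in these examples. The sketch below shows one way such an averaging step could be written with torch.distributed.all_reduce; it is only an assumption about what that helper does, not the original implementation, and the single-process demo at the bottom exists solely so the snippet runs standalone.

# Sketch of cross-process parameter averaging (assumed behavior of avg_model_params).
import os
import torch
import torch.distributed as dist

def avg_model_params_sketch(model):
    # Element-wise average of every parameter tensor across all processes.
    world_size = float(dist.get_world_size())
    with torch.no_grad():
        for param in model.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= world_size

if __name__ == '__main__':
    # Single-process demo with the gloo backend; in the training code above the
    # process group is initialized by do_train via dist.init_process_group.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    avg_model_params_sketch(torch.nn.Linear(4, 2))
    dist.destroy_process_group()
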