Example No. 1
import os
import itertools
import shutil

import imageio
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from sklearn import metrics
from visdom import Visdom

#from utils import logging
from utils import save_checkpoint
from summary import Scalar, Image3D
# train_args, argument_report and SimpleTimer are project-local helpers whose
# import lines were not part of the original snippet.

if __name__ == '__main__':
    FG = train_args()
    vis = Visdom(port=FG.vis_port, env=str(FG.vis_env))
    vis.text(argument_report(FG, end='<br>'), win='config')

    # torch setting
    device = torch.device('cuda:{}'.format(FG.devices[0]))
    torch.cuda.set_device(FG.devices[0])
    timer = SimpleTimer()

    FG.save_dir = str(FG.vis_env)
    if not os.path.exists(FG.save_dir):
        os.makedirs(FG.save_dir)

    printers = dict(lr=Scalar(vis,
                              'lr',
                              opts=dict(showlegend=True)))
    # (the original example is cut off here; the printers dict presumably
    # continued with further Scalar/Image3D entries)
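
Every example on this page calls a project-local train_args() helper whose
definition is not shown. As a rough sketch, assuming it simply wraps argparse
(the flag names below are illustrative guesses, not the project's actual
options):

import argparse

def train_args():
    # Hypothetical reconstruction: collect the common training flags the
    # snippets above read from the returned namespace.
    parser = argparse.ArgumentParser(description='training options')
    parser.add_argument('--devices', type=int, nargs='+', default=[0])
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--vis_port', type=int, default=8097)
    parser.add_argument('--vis_env', type=str, default='main')
    return parser.parse_args()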
Example No. 2
def main():
    # option flags
    FLG = train_args()

    # torch setting
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])

    # create summary and report the option
    visenv = FLG.model
    summary = Summary(port=39199, env=visenv)
    summary.viz.text(argument_report(FLG, end='<br>'),
                     win='report' + str(FLG.running_fold))
    train_report = ScoreReport()
    valid_report = ScoreReport()
    timer = SimpleTimer()
    fold_str = 'fold' + str(FLG.running_fold)
    best_score = dict(epoch=0, loss=1e+100, accuracy=0)

    #### create dataset ###
    # kfold split
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'),
                          allow_pickle=True)
    trainblock, validblock, ratio = fold_split(
        FLG.fold, FLG.running_fold, FLG.labels,
        np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)

    def _dataset(block, transform):
        return ADNIDataset(FLG.labels,
                           pjoin(FLG.data_root, FLG.modal),
                           block,
                           target_dict,
                           transform=transform)

    # create train set
    trainset = _dataset(trainblock, transform_presets(FLG.augmentation))

    # create normal valid set
    validset = _dataset(
        validblock,
        transform_presets('nine crop' if FLG.augmentation ==
                          'random crop' else 'no augmentation'))

    # each loader
    trainloader = DataLoader(trainset,
                             batch_size=FLG.batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True)
    validloader = DataLoader(validset, num_workers=4, pin_memory=True)

    # data check
    # for image, _ in trainloader:
    #     summary.image3d('asdf', image)

    # create model
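    # He (Kaiming) initialization with mode='fan_out' matches the ReLU
    # nonlinearity used throughout these networks.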
    def kaiming_init(tensor):
        return kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')

    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels),
                      name=FLG.model,
                      weights_initializer=kaiming_init)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    else:
        raise NotImplementedError(FLG.model)

    print_model_parameters(model)
    model = torch.nn.DataParallel(model, FLG.devices)
    model.to(device)

    # criterion
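    # The training loss re-weights classes by the reversed split ratio,
    # presumably to counter class imbalance; validation stays unweighted.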
    train_criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor(
        list(map(lambda x: x * 2, reversed(ratio))))).to(device)
    valid_criterion = torch.nn.CrossEntropyLoss().to(device)

    # TODO resume
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLG.lr,
                                 weight_decay=FLG.l2_decay)
    # scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, FLG.lr_gamma)

    start_epoch = 0
    global_step = start_epoch * len(trainloader)
    pbar = None
    for epoch in range(1, FLG.max_epoch + 1):
        timer.tic()
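        # Note: stepping the LR scheduler at the start of the epoch (before
        # any optimizer.step()) follows the pre-1.1 PyTorch convention;
        # PyTorch 1.1+ expects the opposite order and will warn here.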
        scheduler.step()
        summary.scalar('lr',
                       fold_str,
                       epoch - 1,
                       optimizer.param_groups[0]['lr'],
                       ytickmin=0,
                       ytickmax=FLG.lr)

        # train()
        torch.set_grad_enabled(True)
        model.train(True)
        train_report.clear()
        if pbar is None:
            pbar = tqdm(total=len(trainloader) * FLG.validation_term,
                        desc='Epoch {:<3}-{:>3} train'.format(
                            epoch, epoch + FLG.validation_term - 1))
        for images, targets in trainloader:
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            optimizer.zero_grad()

            outputs = model(images)
            loss = train_criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_report.update_true(targets)
            train_report.update_score(F.softmax(outputs, dim=1))

            summary.scalar('loss',
                           'train ' + fold_str,
                           global_step / len(trainloader),
                           loss.item(),
                           ytickmin=0,
                           ytickmax=1)

            pbar.update()
            global_step += 1

        if epoch % FLG.validation_term != 0:
            timer.toc()
            continue
        pbar.close()

        # valid()
        torch.set_grad_enabled(False)
        model.eval()
        valid_report.clear()
        pbar = tqdm(total=len(validloader),
                    desc='Epoch {:>3} valid'.format(epoch))
        for images, targets in validloader:
            true = targets
            npatchs = 1
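            # Nine-crop validation yields a 6-D batch (N, patches, C, X, Y, Z):
            # fold the patch dimension into the batch dimension and replicate
            # the targets so the per-patch loss still lines up.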
            if len(images.shape) == 6:
                _, npatchs, c, x, y, z = images.shape
                images = images.view(-1, c, x, y, z)
                targets = torch.cat([targets
                                     for _ in range(npatchs)]).squeeze()
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            output = model(images)
            loss = valid_criterion(output, targets)

            valid_report.loss += loss.item()

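            # Average the per-patch softmax scores back into a single
            # subject-level prediction (test-time averaging over the crops).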
            if npatchs == 1:
                score = F.softmax(output, dim=1)
            else:
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)
            valid_report.update_true(true)
            valid_report.update_score(score)

            pbar.update()
        pbar.close()

        # report
        vloss = valid_report.loss / len(validloader)
        summary.scalar('accuracy',
                       'train ' + fold_str,
                       epoch,
                       train_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        summary.scalar('loss',
                       'valid ' + fold_str,
                       epoch,
                       vloss,
                       ytickmin=0,
                       ytickmax=0.8)
        summary.scalar('accuracy',
                       'valid ' + fold_str,
                       epoch,
                       valid_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        is_best = False
        if best_score['loss'] > vloss:
            best_score['loss'] = vloss
            best_score['epoch'] = epoch
            best_score['accuracy'] = valid_report.accuracy
            is_best = True

        print('Best Epoch {}: validation loss {} accuracy {}'.format(
            best_score['epoch'], best_score['loss'], best_score['accuracy']))

        # save
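        # Unwrap DataParallel before saving so checkpoint keys carry no
        # 'module.' prefix.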
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()

        save_checkpoint(
            dict(epoch=epoch,
                 best_score=best_score,
                 state_dict=state_dict,
                 optimizer_state_dict=optimizer.state_dict()),
            FLG.checkpoint_root, FLG.running_fold, FLG.model, is_best)
        pbar = None
        timer.toc()
        print('Time elapse {}h {}m {}s'.format(*timer.total()))
Example No. 3

        cluster_cfg = cluster_cfg._replace(dist_url=get_init_file().as_uri())
        train_cfg = train_cfg._replace(**grid_data)
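        # _replace returns a modified copy of the immutable namedtuple config,
        # so every sweep run gets its own cluster_cfg/train_cfg objects.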

        run_name = f"{train_cfg.prefix}"
        for k, v in grid_data.items():
            run_name += "-" + save_key[k](v)
        train_cfg = train_cfg._replace(
            output_dir=os.path.join(log_dir, run_name))

        # Chronos needs a different job name each time
        executor.update_parameters(name=f"sweep_{i:02d}_{uuid.uuid4().hex}")
        trainer = Trainer(train_cfg, cluster_cfg)
        job = executor.submit(trainer)
        jobs.append(job)
        print(
            f"Run {i:02d} submitted with train cfg: {train_cfg}, cluster cfg: {cluster_cfg}"
        )
    print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}")

    # Wait for the master's results of each job
    results = [job.task(0).result() for job in jobs]
    print(f"Jobs results: {results}")
    best_job = np.argmax(results)
    print(
        f"Best configuration: {hyper_parameters[best_job]} (val acc = {results[best_job]:.1%})"
    )


if __name__ == "__main__":
    args = train_args()
    grid_search(args)
Example No. 4
def main():
    args = train_args()

    if args.fp16:
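        # amp does not patch torch.einsum by default; registering it casts
        # its inputs to fp16 when running under apex amp.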
        apex.amp.register_half_function(torch, 'einsum')

    date_curr = date.today().strftime("%m-%d-%Y")
    model_name = f"{args.prefix}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-lr{args.learning_rate}-decay{args.weight_decay}-warm{args.warmup_ratio}-{args.model_name}"
    args.output_dir = os.path.join(args.output_dir, date_curr, model_name)
    tb_logger = SummaryWriter(
        os.path.join(args.output_dir.replace("logs", "tflogs")))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        print(
            f"output directory {args.output_dir} already exists and is not empty."
        )
    os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, "log.txt")),
            logging.StreamHandler()
        ])
    logger = logging.getLogger(__name__)
    logger.info(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    bert_config = AutoConfig.from_pretrained(args.model_name)
    if args.momentum:
        model = MomentumRetriever(bert_config, args)
    elif "roberta" in args.model_name:
        model = RobertaRetrieverSingle(bert_config, args)
    else:
        model = BertRetrieverSingle(bert_config, args)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    collate_fc = partial(sp_collate, pad_id=tokenizer.pad_token_id)

    if args.do_train and args.max_c_len > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (args.max_c_len, bert_config.max_position_embeddings))

    if "fever" in args.predict_file:
        eval_dataset = FeverSingleDataset(tokenizer, args.predict_file,
                                          args.max_q_len, args.max_c_len)
    else:
        eval_dataset = SPDataset(tokenizer, args.predict_file, args.max_q_len,
                                 args.max_c_len)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=args.predict_batch_size,
                                 collate_fn=collate_fc,
                                 pin_memory=True,
                                 num_workers=args.num_workers)
    logger.info(f"Num of dev batches: {len(eval_dataloader)}")

    if args.init_checkpoint != "":
        model = load_saved(model, args.init_checkpoint)

    model.to(device)
    print(
        f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    if args.do_train:
        no_decay = ['bias', 'LayerNorm.weight']
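        # Two parameter groups: biases and LayerNorm weights are exempt from
        # L2 weight decay; everything else decays by args.weight_decay.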
        optimizer_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        optimizer = Adam(optimizer_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)

        if args.fp16:
            model, optimizer = apex.amp.initialize(
                model, optimizer, opt_level=args.fp16_opt_level)
    else:
        if args.fp16:
            model = apex.amp.initialize(model, opt_level=args.fp16_opt_level)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_train:
        global_step = 0  # gradient update step
        batch_step = 0  # forward batch count
        best_mrr = 0
        train_loss_meter = AverageMeter()
        model.train()
        if "fever" in args.predict_file:
            train_dataset = FeverSingleDataset(tokenizer,
                                               args.train_file,
                                               args.max_q_len,
                                               args.max_c_len,
                                               train=True)
        else:
            train_dataset = SPDataset(tokenizer,
                                      args.train_file,
                                      args.max_q_len,
                                      args.max_c_len,
                                      train=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      pin_memory=True,
                                      collate_fn=collate_fc,
                                      num_workers=args.num_workers,
                                      shuffle=True)

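        # The schedule length is counted in optimizer updates, and one update
        # happens every gradient_accumulation_steps batches.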
        t_total = (len(train_dataloader) //
                   args.gradient_accumulation_steps * args.num_train_epochs)
        warmup_steps = t_total * args.warmup_ratio
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=t_total)

        logger.info('Start training....')
        for epoch in range(int(args.num_train_epochs)):

            for batch in tqdm(train_dataloader):
                batch_step += 1
                batch = move_to_cuda(batch)
                loss = loss_single(model, batch, args.momentum)

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                train_loss_meter.update(loss.item())

                if (batch_step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            apex.amp.master_params(optimizer),
                            args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    tb_logger.add_scalar('batch_train_loss', loss.item(),
                                         global_step)
                    tb_logger.add_scalar('smoothed_train_loss',
                                         train_loss_meter.avg, global_step)

                    if args.eval_period != -1 and global_step % args.eval_period == 0:
                        mrr = predict(args, model, eval_dataloader, device,
                                      logger)
                        logger.info(
                            "Step %d Train loss %.2f MRR %.2f on epoch=%d" %
                            (global_step, train_loss_meter.avg, mrr * 100,
                             epoch))

                        if best_mrr < mrr:
                            logger.info(
                                "Saving model with best MRR %.2f -> MRR %.2f on epoch=%d"
                                % (best_mrr * 100, mrr * 100, epoch))
                            torch.save(
                                model.state_dict(),
                                os.path.join(args.output_dir,
                                             f"checkpoint_best.pt"))
                            model = model.to(device)
                            best_mrr = mrr

            mrr = predict(args, model, eval_dataloader, device, logger)
            logger.info("Step %d Train loss %.2f MRR %.2f on epoch=%d" %
                        (global_step, train_loss_meter.avg, mrr * 100, epoch))
            tb_logger.add_scalar('dev_mrr', mrr * 100, epoch)
            # save most recent model every epoch
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, "checkpoint_last.pt"))
            if best_mrr < mrr:
                logger.info(
                    "Saving model with best MRR %.2f -> MRR %.2f on epoch=%d" %
                    (best_mrr * 100, mrr * 100, epoch))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.output_dir, "checkpoint_best.pt"))
                best_mrr = mrr

        logger.info("Training finished!")

    elif args.do_predict:
        acc = predict(args, model, eval_dataloader, device, logger)
        logger.info(f"test performance {acc}")
Example No. 5

def main():
    args = train_args()
    if args.fp16:
        import apex
        apex.amp.register_half_function(torch, 'einsum')
    date_curr = date.today().strftime("%m-%d-%Y")
    model_name = f"{args.prefix}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-lr{args.learning_rate}-decay{args.weight_decay}"
    args.output_dir = os.path.join(args.output_dir, date_curr, model_name)
    tb_logger = SummaryWriter(
        os.path.join(args.output_dir.replace("logs", "tflogs")))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        print(
            f"output directory {args.output_dir} already exists and is not empty."
        )
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, "log.txt")),
            logging.StreamHandler()
        ])
    logger = logging.getLogger(__name__)
    logger.info(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.accumulate_gradients < 1:
        raise ValueError(
            "Invalid accumulate_gradients parameter: {}, should be >= 1".
            format(args.accumulate_gradients))

    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    bert_config = AutoConfig.from_pretrained(args.model_name)
    model = RankModel(bert_config, args)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    collate_fc = partial(rank_collate, pad_id=tokenizer.pad_token_id)
    if args.do_train and args.max_seq_len > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (args.max_seq_len, bert_config.max_position_embeddings))

    eval_dataset = RankingDataset(tokenizer, args.predict_file,
                                  args.max_seq_len, args.max_q_len)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=args.predict_batch_size,
                                 collate_fn=collate_fc,
                                 pin_memory=True,
                                 num_workers=args.num_workers)
    logger.info(f"Num of dev batches: {len(eval_dataloader)}")

    if args.init_checkpoint != "":
        model = load_saved(model, args.init_checkpoint)

    model.to(device)
    print(
        f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    if args.do_train:
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)

        if args.fp16:
            from apex import amp
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)
    else:
        if args.fp16:
            from apex import amp
            model = amp.initialize(model, opt_level=args.fp16_opt_level)

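    # Prefer DistributedDataParallel when launched with a local_rank (one
    # process per GPU); fall back to DataParallel for single-process multi-GPU.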
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_train:
        global_step = 0  # gradient update step
        batch_step = 0  # forward batch count
        best_acc = 0
        train_loss_meter = AverageMeter()
        model.train()
        train_dataset = RankingDataset(tokenizer,
                                       args.train_file,
                                       args.max_seq_len,
                                       args.max_q_len,
                                       train=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      pin_memory=True,
                                      collate_fn=collate_fc,
                                      num_workers=args.num_workers,
                                      shuffle=True)

        logger.info('Start training....')
        for epoch in range(int(args.num_train_epochs)):
            for batch in tqdm(train_dataloader):
                batch_step += 1
                batch_inputs = move_to_cuda(batch["net_inputs"])
                loss = model(batch_inputs)

                if n_gpu > 1:
                    loss = loss.mean()

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                train_loss_meter.update(loss.item())
                tb_logger.add_scalar('batch_train_loss', loss.item(),
                                     global_step)
                tb_logger.add_scalar('smoothed_train_loss',
                                     train_loss_meter.avg, global_step)

                if (batch_step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()  # we have accumulated enough gradients
                    model.zero_grad()
                    global_step += 1

                    if args.eval_period != -1 and global_step % args.eval_period == 0:
                        acc = predict(args, model, eval_dataloader, device,
                                      logger)
                        logger.info(
                            "Step %d Train loss %.2f acc %.2f on epoch=%d" %
                            (global_step, train_loss_meter.avg, acc * 100,
                             epoch))

                        # save most recent model
                        torch.save(
                            model.state_dict(),
                            os.path.join(args.output_dir,
                                         f"checkpoint_last.pt"))

                        if best_acc < acc:
                            logger.info(
                                "Saving model with best acc %.2f -> acc %.2f on epoch=%d"
                                % (best_acc * 100, acc * 100, epoch))
                            torch.save(
                                model.state_dict(),
                                os.path.join(args.output_dir,
                                             f"checkpoint_best.pt"))
                            model = model.to(device)
                            best_acc = acc

            acc = predict(args, model, eval_dataloader, device, logger)
            logger.info("Step %d Train loss %.2f acc %.2f on epoch=%d" %
                        (global_step, train_loss_meter.avg, acc * 100, epoch))
            tb_logger.add_scalar('dev_acc', acc * 100, epoch)
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, f"checkpoint_last.pt"))

            if best_acc < acc:
                logger.info(
                    "Saving model with best acc %.2f -> acc %.2f on epoch=%d" %
                    (best_acc * 100, acc * 100, epoch))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.output_dir, f"checkpoint_best.pt"))
                best_acc = acc

        logger.info("Training finished!")

    elif args.do_predict:
        acc = predict(args, model, eval_dataloader, device, logger)
        logger.info(f"test performance {acc}")
Example No. 6
            normal_name = '%s/epoch_%02d_loss_%.4f_loss_l_%.4f_loss_c_%.4f.pth' % \
                             (self.args.save_to, self.result['epoch'], sum(lossinfo), lossinfo[0], lossinfo[1])
            shutil.copy(save_name, normal_name)

    def _training(self):

        self.result['min_loss'] = 1e5
        for epoch in range(self.args.start_epoch, self.args.end_epoch + 1):

            start_time = time.time()
            self.result['epoch'] = epoch
            loss = self._train_one_epoch()
            self.model['scheduler'].step()
            finish_time = time.time()
            print('single epoch costs %.4f mins' %
                  ((finish_time - start_time) / 60))
            self._save_weights(loss)
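            # In debug mode, stop after a single epoch.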
            if self.args.is_debug:
                break

    def main_runner(self):
        self._model_loader()
        self._data_loader()
        self._training()


if __name__ == "__main__":

    fas = FaceBoxesTrainer(args=train_args())
    fas.main_runner()