Example #1
def train(args):
    is_training = True

    # TF1-style session shared by the input pipeline and the trainer
    session = tf.compat.v1.Session(config=config.TF_SESSION_CONFIG)
    dataset = AudioWrapper(args, 'train', is_training, session)
    wavs, labels = dataset.get_input_and_output_op()

    # look up the model class by name and build its graph
    model = models.__dict__[args.arch](args)
    model.build(wavs=wavs, labels=labels, is_training=is_training)

    trainer = Trainer(model, session, args, dataset)
    trainer.train()
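
The Trainer class is not shown in this snippet. Below is a minimal sketch of a TF1-style trainer matching the Trainer(model, session, args, dataset) constructor above; the train_op, loss, and max_steps names are assumptions for illustration, not the project's actual API:

import tensorflow as tf

class Trainer:
    """Sketch: drives the model's train op through the shared session."""

    def __init__(self, model, session, args, dataset):
        self.model = model
        self.session = session
        self.args = args
        self.dataset = dataset

    def train(self):
        # initialize graph variables once, then step until max_steps
        self.session.run(tf.compat.v1.global_variables_initializer())
        for step in range(self.args.max_steps):  # max_steps: assumed arg
            _, loss = self.session.run([self.model.train_op, self.model.loss])
            if step % 100 == 0:
                print('step %d: loss %.4f' % (step, loss))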
Example #2
writer = SummaryWriter(log_dir=os.path.join(log_path, task_name))
f_log = open(os.path.join(log_path, task_name + ".log"), 'w')

trainer = Trainer(criterion, optimizer, n_class, size_g, size_p, sub_batch_size, mode, lamb_fmreg)
evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode, test)

best_pred = 0.0
print("start training......")
for epoch in range(num_epochs):
    trainer.set_train(model)
    optimizer.zero_grad()
    tbar = tqdm(dataloader_train)
    train_loss = 0
    for i_batch, sample_batched in enumerate(tbar):
        if evaluation:  # evaluation mode: skip training
            break
        scheduler(optimizer, i_batch, epoch, best_pred)
        loss = trainer.train(sample_batched, model, global_fixed)
        train_loss += loss.item()
        score_train, score_train_global, score_train_local = trainer.get_scores()
        if mode == 1:
            tbar.set_description('Train loss: %.3f; global mIoU: %.3f' %
                                 (train_loss / (i_batch + 1),
                                  np.mean(np.nan_to_num(score_train_global["iou"]))))
        else:
            tbar.set_description('Train loss: %.3f; agg mIoU: %.3f' %
                                 (train_loss / (i_batch + 1),
                                  np.mean(np.nan_to_num(score_train["iou"]))))

    score_train, score_train_global, score_train_local = trainer.get_scores()
    trainer.reset_metrics()
    # torch.cuda.empty_cache()

    if epoch % 1 == 0:
        with torch.no_grad():
            model.eval()
            print("evaluating...")

            if test:
                tbar = tqdm(dataloader_test)
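
The scheduler above is invoked once per iteration as scheduler(optimizer, i_batch, epoch, best_pred). Below is a sketch of a callable scheduler with that signature; the poly and cosine decay rules are assumptions, not necessarily what LR_Scheduler in these projects implements:

import math

class LR_Scheduler:
    """Sketch of a per-iteration LR scheduler matching the
    scheduler(optimizer, i_batch, epoch, best_pred) call above."""

    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch):
        self.mode = mode
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.total_iters = num_epochs * iters_per_epoch

    def __call__(self, optimizer, i, epoch, best_pred):
        t = epoch * self.iters_per_epoch + i
        if self.mode == 'poly':
            lr = self.base_lr * (1 - t / self.total_iters) ** 0.9
        else:  # assumed fallback: cosine annealing
            lr = 0.5 * self.base_lr * (1 + math.cos(math.pi * t / self.total_iters))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr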
Example #3
File: main.py  Project: yida2311/HistoDOI
def main(seed=25):
    seed_everything(seed)  # use the seed argument rather than a hard-coded 25
    device = torch.device('cuda:0')

    # arguments
    args = Args().parse()
    n_class = args.n_class

    img_path_train = args.img_path_train
    mask_path_train = args.mask_path_train
    img_path_val = args.img_path_val
    mask_path_val = args.mask_path_val

    model_path = os.path.join(args.model_path, args.task_name)  # save model
    log_path = args.log_path
    output_path = args.output_path

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    task_name = args.task_name
    print(task_name)
    ###################################
    evaluation = args.evaluation
    test = evaluation and False  # test mode is hard-disabled here
    print("evaluation:", evaluation, "test:", test)

    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = args.batch_size
    num_workers = args.num_workers
    config = args.config

    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')

    dataset_train = DoiDataset(img_path_train,
                               config,
                               train=True,
                               root_mask=mask_path_train)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    dataset_val = DoiDataset(img_path_val,
                             config,
                             train=True,
                             root_mask=mask_path_val)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

    ###################################
    print("creating models......")
    model = DoiNet(n_class, config['min_descriptor'] + 6, 4)
    model = create_model_load_weights(model,
                                      evaluation=False,
                                      ckpt_path=args.ckpt_path)
    model.to(device)

    ###################################
    num_epochs = args.epochs
    learning_rate = args.lr

    optimizer = get_optimizer(model, learning_rate=learning_rate)
    scheduler = LR_Scheduler(args.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    criterion_node = nn.CrossEntropyLoss()
    criterion_edge = nn.BCELoss()
    alpha = args.alpha

    # join paths explicitly so the task name isn't glued onto the directory
    writer = SummaryWriter(log_dir=os.path.join(log_path, task_name))
    f_log = open(os.path.join(log_path, task_name + ".log"), 'w')
    #######################################
    trainer = Trainer(criterion_node,
                      criterion_edge,
                      optimizer,
                      n_class,
                      device,
                      alpha=alpha)
    evaluator = Evaluator(n_class, device)

    best_pred = 0.0
    print("start training......")
    log = task_name + '\n'
    for k, v in args.__dict__.items():
        log += str(k) + ' = ' + str(v) + '\n'
    print(log)
    f_log.write(log)
    f_log.flush()

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0
        train_loss_edge = 0
        train_loss_node = 0

        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)

            if evaluation:  # evaluation mode: skip training
                break
            scheduler(optimizer, i_batch, epoch, best_pred)
            loss, loss_node, loss_edge = trainer.train(sample, model)
            train_loss += loss.item()
            train_loss_node += loss_node.item()
            train_loss_edge += loss_edge.item()
            train_scores_node, train_scores_edge = trainer.get_scores()

            batch_time.update(time.time() - start_time)
            start_time = time.time()

            if i_batch % 2 == 0:
                tbar.set_description(
                    'Train loss: %.4f (loss_node=%.4f  loss_edge=%.4f); F1 node: %.4f  F1 edge: %.4f; data time: %.2f; batch time: %.2f'
                    % (train_loss / (i_batch + 1), train_loss_node /
                       (i_batch + 1), train_loss_edge /
                       (i_batch + 1), train_scores_node["macro_f1"],
                       train_scores_edge["macro_f1"], data_time.avg,
                       batch_time.avg))

        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()

        if epoch % 1 == 0:  # evaluate every epoch
            with torch.no_grad():
                model.eval()
                print("evaluating...")

                tbar = tqdm(dataloader_val)
                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    pred_node, pred_edge = evaluator.eval(sample, model)
                    val_scores_node, val_scores_edge = evaluator.get_scores()

                    batch_time.update(time.time() - start_time)
                    tbar.set_description(
                        'F1 node: %.4f  F1 edge: %.4f; data time: %.2f; batch time: %.2f'
                        % (val_scores_node["macro_f1"],
                           val_scores_edge["macro_f1"], data_time.avg,
                           batch_time.avg))
                    start_time = time.time()

            data_time.reset()
            batch_time.reset()
            val_scores_node, val_scores_edge = evaluator.get_scores()
            evaluator.reset_metrics()

            best_pred = save_model(model, model_path, val_scores_node,
                                   val_scores_edge, alpha, task_name, epoch,
                                   best_pred)
            write_log(f_log, train_scores_node, train_scores_edge,
                      val_scores_node, val_scores_edge, epoch, num_epochs)
            write_summaryWriter(writer, train_loss / len(dataloader_train),
                                optimizer, train_scores_node,
                                train_scores_edge, val_scores_node,
                                val_scores_edge, epoch)

    f_log.close()
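
AverageMeter is constructed above as AverageMeter("DataTime", ':3.3f') and used through update(), reset(), and .avg. A minimal sketch compatible with that usage, modeled on the common PyTorch ImageNet-example helper (the project's real class may differ):

class AverageMeter:
    """Sketch: tracks the latest value and running average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count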
Example #4
best_pred = 0.0
print("start training......")

for epoch in range(num_epochs):
    optimizer.zero_grad()
    tbar = tqdm(dataloader_train)
    train_loss = 0
    start_time = time.time()
    for i_batch, sample_batched in enumerate(tbar):
        data_time.update(time.time() - start_time)
        if evaluation:  # evaluation mode: skip training
            break
        scheduler(optimizer, i_batch, epoch, best_pred)
        loss = trainer.train(sample_batched, model)
        train_loss += loss.item()

        score_train = trainer.get_scores()
        precision = score_train['precision']
        mAP = score_train['mAP']

        batch_time.update(time.time() - start_time)
        start_time = time.time()
        tbar.set_description('Train loss: %.3f; precision: %.3f; mAP: %.3f; data time: %.3f; batch time: %.3f' %
                             (train_loss / (i_batch + 1), precision, mAP, data_time.avg, batch_time.avg))
    writer.add_scalar('loss', train_loss / len(tbar), epoch)
    writer.add_scalar('mAP/train', mAP, epoch)
    writer.add_scalar('precision/train', precision, epoch)
    writer.add_scalar('distance/train', score_train['distance'], epoch)
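
Example #4 stops after logging the training scalars. A common follow-up, mirroring the save_model call in Example #3, is to checkpoint whenever the tracked metric improves; this helper is an assumption for illustration, not code from the project:

import os
import torch

def save_if_best(model, model_path, task_name, metric, best_pred):
    """Sketch: keep the weights whose tracked metric (e.g. mAP) is best."""
    if metric > best_pred:
        best_pred = metric
        torch.save(model.state_dict(),
                   os.path.join(model_path, task_name + '_best.pth'))
    return best_pred

# usage at the end of each epoch, e.g.:
# best_pred = save_if_best(model, model_path, task_name, mAP, best_pred)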