Example #1
def train(args):

    # Setup Dataloader
    data_loader = get_loader('doc3dbmnic')
    data_path = args.data_path
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Model
    model = get_model(args.arch, n_classes, in_channels=3)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.l_rate,
                                 weight_decay=5e-4,
                                 amsgrad=True)

    # LR Scheduler
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.5,
                                                       patience=3,
                                                       verbose=True)

    # Losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    reconst_loss = recon_lossc.Unwarploss()

    epoch_start = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            # optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            epoch_start = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # Log file:
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    # network_activation(t=[-1,1])_dataset_lossparams_augmentations_trainstart
    experiment_name = 'dnetccnl_htan_swat3dmini1kbm_l1_noaug_scratch'
    log_file_name = os.path.join(args.logdir, experiment_name + '.txt')
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w+')

    log_file.write('\n---------------  ' + experiment_name +
                   '  ---------------\n')
    log_file.close()

    # Setup tensorboard for visualization
    if args.tboard:
        # save logs in runs/<experiment_name>
        writer = SummaryWriter(comment=experiment_name)

    best_val_uwarpssim = 99999.0
    best_val_mse = 99999.0
    global_step = 0

    for epoch in range(epoch_start, args.n_epoch):
        avg_loss = 0.0
        avgl1loss = 0.0
        avgrloss = 0.0
        avgssimloss = 0.0
        train_mse = 0.0
        model.train()

        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            optimizer.zero_grad()
            target = model(images[:, 3:, :, :])
            target_nhwc = target.transpose(1, 2).transpose(2, 3)
            l1loss = loss_fn(target_nhwc, labels)
            rloss, ssim, uworg, uwpred = reconst_loss(images[:, :-1, :, :],
                                                      target_nhwc, labels)
            loss = (10.0 * l1loss) + (0.5 * rloss)  #+ (0.3*ssim)
            # loss=l1loss
            avgl1loss += float(l1loss)
            avg_loss += float(loss)
            avgrloss += float(rloss)
            avgssimloss += float(ssim)

            train_mse += MSE(target_nhwc, labels).item()

            loss.backward()
            optimizer.step()
            global_step += 1

            if (i + 1) % 50 == 0:
                avg_loss = avg_loss / 50
                print("Epoch[%d/%d] Batch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader),
                       avg_loss))
                avg_loss = 0.0

            if args.tboard and (i + 1) % 20 == 0:
                show_unwarp_tnsboard(global_step, writer, uwpred, uworg, 8,
                                     'Train GT unwarp', 'Train Pred Unwarp')
                writer.add_scalar('BM: L1 Loss/train', avgl1loss / (i + 1),
                                  global_step)
                writer.add_scalar('CB: Recon Loss/train', avgrloss / (i + 1),
                                  global_step)
                writer.add_scalar('CB: SSIM Loss/train', avgssimloss / (i + 1),
                                  global_step)

        avgssimloss = avgssimloss / len(trainloader)
        avgrloss = avgrloss / len(trainloader)
        avgl1loss = avgl1loss / len(trainloader)
        train_mse = train_mse / len(trainloader)
        print("Training L1:%4f" % (avgl1loss))
        print("Training MSE:'{}'".format(train_mse))
        train_losses = [avgl1loss, train_mse, avgrloss, avgssimloss]
        lrate = get_lr(optimizer)
        write_log_file(log_file_name, train_losses, epoch + 1, lrate, 'Train')

        model.eval()
        val_loss = 0.0
        val_l1loss = 0.0
        val_mse = 0.0
        val_rloss = 0.0
        val_ssimloss = 0.0

        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = Variable(images_val.cuda())
                labels_val = Variable(labels_val.cuda())
                target = model(images_val[:, 3:, :, :])
                target_nhwc = target.transpose(1, 2).transpose(2, 3)
                pred = target_nhwc.data.cpu()
                gt = labels_val.cpu()
                l1loss = loss_fn(target_nhwc, labels_val)
                rloss, ssim, uworg, uwpred = reconst_loss(
                    images_val[:, :-1, :, :], target_nhwc, labels_val)
                val_l1loss += float(l1loss.cpu())
                val_rloss += float(rloss.cpu())
                val_ssimloss += float(ssim.cpu())
                val_mse += float(MSE(pred, gt))
            if args.tboard:
                show_unwarp_tnsboard(epoch + 1, writer, uwpred, uworg, 8,
                                     'Val GT unwarp', 'Val Pred Unwarp')

        val_l1loss = val_l1loss / len(valloader)
        val_mse = val_mse / len(valloader)
        val_ssimloss = val_ssimloss / len(valloader)
        val_rloss = val_rloss / len(valloader)
        print("val loss at epoch {}:: {}".format(epoch + 1, val_l1loss))
        print("val mse: {}".format(val_mse))
        val_losses = [val_l1loss, val_mse, val_rloss, val_ssimloss]
        write_log_file(log_file_name, val_losses, epoch + 1, lrate, 'Val')
        if args.tboard:
            # log the val losses
            writer.add_scalar('BM: L1 Loss/val', val_l1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/val', val_rloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/val', val_ssimloss, epoch + 1)

        #reduce learning rate
        sched.step(val_mse)

        if val_mse < best_val_mse:
            best_val_mse = val_mse
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state,
                os.path.join(args.logdir,
                             "{}_{}_{}_{}_{}_best_model.pkl".format(
                                 args.arch, epoch + 1, val_mse, train_mse,
                                 experiment_name)))

        if (epoch + 1) % 10 == 0:
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state,
                os.path.join(args.logdir,
                             "{}_{}_{}_{}_{}_model.pkl".format(
                                 args.arch, epoch + 1, val_mse, train_mse,
                                 experiment_name)))
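
Several of these snippets call small helpers (get_lr, write_log_file) that are not shown. A minimal sketch consistent with how they are used above, assuming get_lr reads the first parameter group and write_log_file appends one formatted line per call (the real helpers may differ):

def get_lr(optimizer):
    # Learning rate of the first parameter group.
    return optimizer.param_groups[0]['lr']

def write_log_file(file_name, losses, epoch, lrate, phase):
    # Append one line: phase, epoch, learning rate, then the loss values.
    with open(file_name, 'a') as f:
        f.write('\n{} Epoch {}: lr {:.6f} losses {}'.format(
            phase, epoch, lrate, ' '.join('%.6f' % l for l in losses)))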
Example #2
            """

        epoch_end_time = time.time() - epoch_start_time
        logger.info(
            "Epoch {}/{}, Training Loss {:.4f}, Classify Loss: {:.4f}, Domain Loss: {:.4f}, Time/Epoch: {:.4f}"
            .format(epoch, total_epoch, running_loss / train_dataset_size,
                    running_cls_loss / train_dataset_size,
                    running_domain_loss / train_dataset_size, epoch_end_time))
        train_loss.append(running_loss / train_dataset_size)
        writer.add_scalar('Training_Loss', running_loss / train_dataset_size,
                          epoch + 1)
        writer.add_scalar('Classification_Loss',
                          running_cls_loss / train_dataset_size, epoch + 1)
        writer.add_scalar('Domain_Loss',
                          running_domain_loss / train_dataset_size, epoch + 1)
        writer.add_scalar('Learning_Rate', get_lr(optimizer), epoch + 1)

        # scheduler_warmup.step(epoch, metrics=(running_loss/train_dataset_size))

        ## =========================
        # model validation
        ## =========================
        if epoch % val_model_epoch == 0:
            logger.info("Validating model...")
            validation_loss = 0.0
            validation_corrects = 0

            mdan.eval()
            criterion = nn.CrossEntropyLoss()

            dataloader = validate_dataloader[target]
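            # --- Assumed continuation (the snippet is cut off here): a plain
            # --- accumulate-and-average validation loop. Calling mdan directly
            # --- on a batch at eval time, and the existence of a `device`
            # --- variable in the full script, are both assumptions.
            with torch.no_grad():
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = mdan(inputs)
                    loss = criterion(outputs, labels)
                    validation_loss += loss.item() * inputs.size(0)
                    _, preds = torch.max(outputs, 1)
                    validation_corrects += torch.sum(preds == labels).item()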
Example #3
def main():
    # SET THE PARAMETERS
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Initial learning rate (default: 1e-3)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Maximum number of epochs (default: 100)')
    parser.add_argument('--patience', type=int, default=10,
                        help='lr scheduler patience (default: 10)')
    parser.add_argument('--batch', type=int, default=4,
                        help='Batch size (default: 4)')
    parser.add_argument('--name', type=str, default='Prueba',
                        help='Name of the current test (default: Prueba)')

    parser.add_argument('--load_model', type=str, default='best_acc',
                        help='Weights to load (default: best_acc)')
    parser.add_argument('--test', action='store_false', default=True,
                        help='Only test the model')
    parser.add_argument('--resume', action='store_true', default=False,
                        help='Continue training a model')
    parser.add_argument('--load_path', type=str, default=None,
                        help='Name of the folder with the pretrained model')
    parser.add_argument('--ft', action='store_true', default=False,
                        help='Fine-tune a model')
    parser.add_argument('--psFactor', type=float, default=1,
                        help='Patch size scaling factor (default: 1)')

    parser.add_argument('--gpu', type=str, default='0',
                        help='GPU(s) to use (default: 0)')
    args = parser.parse_args()

    training = args.test  # note: passing --test stores False here, disabling training
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.ft:
        args.resume = True

    args.patch_size = [int(128*args.psFactor), int(128*args.psFactor), int(96*args.psFactor)]
    args.num_classes = 2

    # PATHS AND DIRS
    save_path = os.path.join('TRAIN', args.name)
    out_path = os.path.join(save_path, 'Val')
    load_path = save_path
    if args.load_path is not None:
        load_path = os.path.join('TRAIN/', args.load_path)

    root = '../../Data/Heart'
    train_file = 'train_paths.csv'
    test_file = 'val_paths.csv'

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(out_path, exist_ok=True)

    # SEEDS
    np.random.seed(12345)
    torch.manual_seed(12345)

    cudnn.deterministic = False
    cudnn.benchmark = True

    # CREATE THE NETWORK ARCHITECTURE
    model = GNet(num_classes=args.num_classes, backbone='xception')
    print('---> Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    model = model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=1e-5, amsgrad=True)

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    annealing = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=args.patience, threshold=0.001,
        factor=0.5, threshold_mode="abs")

    criterion = utils.segmentation_loss(alpha=1)
    metrics = utils.Evaluator(args.num_classes)

    # LOAD A MODEL IF NEEDED (TESTING OR CONTINUE TRAINING)
    args.epoch = 0
    best_acc = 0
    if args.resume or not training:
        name = 'epoch_' + args.load_model + '.pth.tar'
        checkpoint = torch.load(
            os.path.join(load_path, name),
            map_location=lambda storage, loc: storage)
        args.epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        args.lr = checkpoint['lr']

        print('Loading model and optimizer from epoch {}.'.format(args.epoch))

        amp.load_state_dict(checkpoint['amp'])
        model.load_state_dict(checkpoint['state_dict'], strict=(not args.ft))
        if not args.ft:
            optimizer.load_state_dict(checkpoint['optimizer'])

    # DATALOADERS
    train_loader = Read_data.MRIdataset(True, train_file, root,
                                        args.patch_size)
    val_loader = Read_data.MRIdataset(True, test_file, root, args.patch_size,
                                      val=True)
    test_loader = Read_data.MRIdataset(False, test_file, root, args.patch_size)

    train_loader = DataLoader(train_loader, shuffle=True, sampler=None,
                              batch_size=args.batch, num_workers=10)
    val_loader = DataLoader(val_loader, shuffle=False, sampler=None,
                            batch_size=args.batch * 2, num_workers=10)
    test_loader = DataLoader(test_loader, shuffle=False, sampler=None,
                             batch_size=1, num_workers=0)

    # TRAIN THE MODEL
    is_best = True
    if training:
        torch.cuda.empty_cache()
        out_file = open(os.path.join(save_path, 'progress.csv'), 'a+')

        for epoch in range(args.epoch + 1, args.epochs + 1):
            args.epoch = epoch
            lr = utils.get_lr(optimizer)
            print('--------- Starting Epoch {} --> {} ---------'.format(
                epoch, time.strftime("%H:%M:%S")))
            print('Learning rate:', lr)

            train_loss = train(args, model, train_loader, optimizer, criterion)
            val_loss, acc = val(args, model, val_loader, criterion, metrics)

            acc = acc.item()
            out_file.write('{},{},{},{}\n'.format(
                args.epoch, train_loss, val_loss, acc))
            out_file.flush()

            annealing.step(val_loss)
            save_graph(save_path)

            is_best = best_acc < acc
            best_acc = max(best_acc, acc)

            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict(),
                'loss': [train_loss, val_loss],
                'lr': lr,
                'acc': acc,
                'best_acc': best_acc}

            checkpoint = epoch % 20 == 0
            utils.save_epoch(state, save_path, epoch,
                             checkpoint=checkpoint, is_best=is_best)

            if lr <= (args.lr / (10 ** 3)):
                print('Stopping training: learning rate is too small')
                break
        out_file.close()

    # TEST THE MODEL
    if not is_best:
        checkpoint = torch.load(
            os.path.join(save_path, 'epoch_best_acc.pth.tar'),
            map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        args.epoch = checkpoint['epoch']
        print('Testing epoch with best dice ({}: acc {})'.format(
            args.epoch, checkpoint['acc']))

    test(args, model, test_loader, out_path, test_file)
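
Example #3 initializes apex.amp at opt level O1, so the backward pass inside its train()/val() helpers (not shown) would normally go through loss scaling. A sketch of that standard apex idiom, under the assumption that the helpers follow it:

from apex import amp

def amp_train_step(model, inputs, target, criterion, optimizer):
    # Standard apex.amp step: scale the loss to avoid fp16 gradient underflow.
    optimizer.zero_grad()
    loss = criterion(model(inputs), target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    return loss.item()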
Example #4
def train_epoch(epoch,
                data_loader1,
                data_loader2,
                model,
                criterion,
                optimizer,
                device,
                current_lr,
                epoch_logger,
                batch_logger,
                is_master_node,
                tb_writer=None,
                distributed=False):

    print('train at epoch {}'.format(epoch))

    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    # data_loader = data_loader1+data_loader2
    # for i,(inputs,targets) in enumerate(data_loader):
    # for i,(data1,data2) in enumerate(zip(data_loader1,data_loader2)):
    dataloader_iterator = iter(data_loader2)
    for i, data1 in enumerate(data_loader1):

        try:
            data2 = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(data_loader2)
            data2 = next(dataloader_iterator)

        data_time.update(time.time() - end_time)

        inputs1, targets1 = data1
        inputs2, targets2 = data2
        inputs = torch.cat((inputs1, inputs2), 0)

        targets = torch.cat((targets1, targets2), 0)

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()
        itera = (epoch - 1) * int(len(data_loader1)) + (i + 1)
        batch_lr = get_lr(optimizer)
        if is_master_node:
            if tb_writer is not None:

                tb_writer.add_scalar('train_iter/loss_iter', losses.val, itera)
                tb_writer.add_scalar('train_iter/acc_iter', accuracies.val,
                                     itera)
                tb_writer.add_scalar('train_iter/lr_iter', batch_lr, itera)

        if batch_logger is not None:
            batch_logger.log({
                'epoch': epoch,
                'batch': i + 1,
                'iter': itera,
                'loss': losses.val,
                'acc': accuracies.val,
                'lr': current_lr
            })

        local_rank = 0
        if is_master_node:
            print('Train Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc {acc.val:.3f} ({acc.avg:.3f})\t'
                  'RANK {rank}'.format(epoch,
                                       i + 1,
                                       len(data_loader1),
                                       batch_time=batch_time,
                                       data_time=data_time,
                                       loss=losses,
                                       acc=accuracies,
                                       rank=local_rank))

        if distributed:
            loss_sum = torch.tensor([losses.sum],
                                    dtype=torch.float32,
                                    device=device)
            loss_count = torch.tensor([losses.count],
                                      dtype=torch.float32,
                                      device=device)
            acc_sum = torch.tensor([accuracies.sum],
                                   dtype=torch.float32,
                                   device=device)
            acc_count = torch.tensor([accuracies.count],
                                     dtype=torch.float32,
                                     device=device)

            dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

            losses.avg = loss_sum.item() / loss_count.item()
            accuracies.avg = acc_sum.item() / acc_count.item()

    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': losses.avg,
            'acc': accuracies.avg,
            'lr': current_lr,
            'rank': local_rank
        })
    if is_master_node:
        if tb_writer is not None:
            tb_writer.add_scalar('train/loss', losses.avg, epoch)
            tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
            tb_writer.add_scalar('train/lr', current_lr, epoch)
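
AverageMeter above is a common utility that is not defined in the snippet; a minimal sketch matching its .val, .avg, .sum, .count, and .update(value, n) usage (the distributed branch also assigns .avg directly, which plain attributes allow):

class AverageMeter:
    # Tracks the latest value and a running average.
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count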
Example #5
import tensorflow as tf

from data_loader import get_batch_fn
from model import create_model
import utils

model = create_model(True)

optimizer = tf.keras.optimizers.Adam(utils.get_lr(0))
loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

loss_metric = tf.keras.metrics.Mean("loss_metric")
accuracy_metric = tf.keras.metrics.CategoricalAccuracy(name="accuracy_metric")
epoch_loss_metric = tf.keras.metrics.Mean("epoch_loss_metric")
epoch_accuracy_metric = tf.keras.metrics.CategoricalAccuracy(name="epoch_accuracy_metric")

ckpt = tf.train.Checkpoint(model=model)
ckpt_manager = tf.train.CheckpointManager(ckpt, utils.ckpt_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored: {}'.format(ckpt_manager.latest_checkpoint))


@tf.function(experimental_relax_shapes=True)
def train_step(images, maps, keys):
    with tf.GradientTape() as tape:
        pred = model([images, maps], training=True)

        images_losses = loss_obj(y_true=keys, y_pred=pred[0])
        combined_losses = loss_obj(y_true=keys, y_pred=pred[1])
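        # --- Assumed continuation (the snippet is cut off here): sum the two
        # --- head losses (equal weighting is an assumption), apply gradients,
        # --- and update the running metrics defined above.
        loss = images_losses + combined_losses

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    loss_metric(loss)
    accuracy_metric(keys, pred[1])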
Example #6
def main_worker(index, opt):
    random.seed(opt.manual_seed)
    np.random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)

    if index >= 0 and opt.device.type == 'cuda':
        opt.device = torch.device(f'cuda:{index}')

    opt.is_master_node = not opt.distributed or opt.dist_rank == 0

    model = generate_model(opt)
    print('after generating model:', model.fc.in_features, ':',
          model.fc.out_features)
    print('feature weights:', model.fc.weight.shape, ':', model.fc.bias.shape)

    if opt.resume_path is not None:
        model = resume_model(opt.resume_path, opt.arch, model)
    print('after resume model:', model.fc.in_features, ':',
          model.fc.out_features)
    print('feature weights:', model.fc.weight.shape, ':', model.fc.bias.shape)
    # summary(model, input_size=(3, 112, 112))
    #    if opt.pretrain_path:
    #        model = load_pretrained_model(model, opt.pretrain_path, opt.model,
    #                                      opt.n_finetune_classes)

    print('after pretrained model:', model.fc.in_features, ':',
          model.fc.out_features)
    print('feature weights:', model.fc.weight.shape, ':', model.fc.bias.shape)
    print(torch_summarize(model))
    # parameters = model.parameters()
    # for name, param in model.named_parameters():
    #     if param.requires_grad:
    #         print(name, param.data)
    #    summary(model, (3, 112, 112))
    #    return

    #    print('model parameters shape', parameters.shape)

    (train_loader, train_sampler, train_logger, train_batch_logger, optimizer,
     scheduler) = get_train_utils(opt, model.parameters())

    for i, (inputs, targets) in enumerate(train_loader):
        print('input shape:', inputs.shape)
        print('targets shape:', targets.shape)
        outputs = model(inputs)
        print("output shape", outputs.shape)
        model_arch = make_dot(outputs, params=dict(model.named_parameters()))
        print(model_arch)
        model_arch.render("/apollo/data/model.png", format="png")
        # Source(model_arch).render('/apollo/data/model.png')
        # print("generating /apollo/data/model.png")
        break

    # make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

    # Early return left in for debugging: everything below is unreachable.
    return

    if opt.batchnorm_sync:
        assert opt.distributed, 'SyncBatchNorm only supports DistributedDataParallel.'
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opt.pretrain_path:
        model = load_pretrained_model(model, opt.pretrain_path, opt.model,
                                      opt.n_finetune_classes)
    if opt.resume_path is not None:
        model = resume_model(opt.resume_path, opt.arch, model)
    model = make_data_parallel(model, opt.distributed, opt.device)

    if opt.pretrain_path:
        parameters = get_fine_tuning_parameters(model, opt.ft_begin_module)
    else:
        parameters = model.parameters()

    if opt.is_master_node:
        print(model)

    criterion = CrossEntropyLoss().to(opt.device)

    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)

    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None

    prev_val_loss = None
    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, tb_writer, opt.distributed)

            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)

        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, tb_writer,
                                      opt.distributed)

        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)

    if opt.inference:
        inference_loader, inference_class_names = get_inference_utils(opt)
        inference_result_path = opt.result_path / '{}.json'.format(
            opt.inference_subset)

        inference.inference(inference_loader, model, inference_result_path,
                            inference_class_names, opt.inference_no_average,
                            opt.output_topk)
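
The make_dot(...).render(...) calls above are the torchviz API; for reference, a self-contained minimal usage (the model and output path are illustrative):

import torch
from torchviz import make_dot

net = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)
graph = make_dot(net(x), params=dict(net.named_parameters()))
graph.render("model_graph", format="png")  # writes model_graph.png plus the DOT source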
Example #7
def train(args):

    # Setup Dataloader
    data_loader = get_loader('doc3dwc')
    data_path = args.data_path
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           augmentations=False)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Model
    model = get_model(args.arch, n_classes, in_channels=3)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Activation
    htan = nn.Hardtanh(0, 1.0)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.l_rate,
                                 weight_decay=5e-4,
                                 amsgrad=True)

    # LR Scheduler
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.5,
                                                       patience=5,
                                                       verbose=True)

    # Losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    gloss = grad_loss.Gradloss(window_size=5, padding=2)

    epoch_start = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            # optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            epoch_start = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # Log file:
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    # activation_dataset_lossparams_augmentations_trainstart
    experiment_name = 'htan_doc3d_l1grad_bghsaugk_scratch'
    log_file_name = os.path.join(args.logdir, experiment_name + '.txt')
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w+')

    log_file.write('\n---------------  ' + experiment_name +
                   '  ---------------\n')
    log_file.close()

    # Setup tensorboard for visualization
    if args.tboard:
        # save logs in runs/<experiment_name>
        writer = SummaryWriter(comment=experiment_name)

    best_val_mse = 99999.0
    global_step = 0

    for epoch in range(epoch_start, args.n_epoch):
        avg_loss = 0.0
        avg_l1loss = 0.0
        avg_gloss = 0.0
        train_mse = 0.0
        model.train()

        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images)
            pred = htan(outputs)
            g_loss = gloss(pred, labels)
            l1loss = loss_fn(pred, labels)
            loss = l1loss  # +(0.2*g_loss)
            avg_l1loss += float(l1loss)
            avg_gloss += float(g_loss)
            avg_loss += float(loss)
            train_mse += float(MSE(pred, labels).item())

            loss.backward()
            optimizer.step()
            global_step += 1

            if (i + 1) % 10 == 0:
                print("Epoch[%d/%d] Batch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader),
                       avg_loss / 10.0))
                avg_loss = 0.0

            if args.tboard and (i + 1) % 10 == 0:
                show_wc_tnsboard(global_step, writer, images, labels, pred, 8,
                                 'Train Inputs', 'Train WCs',
                                 'Train Pred. WCs')
                writer.add_scalars(
                    'Train', {
                        'WC_L1 Loss/train': avg_l1loss / (i + 1),
                        'WC_Grad Loss/train': avg_gloss / (i + 1)
                    }, global_step)

        train_mse = train_mse / len(trainloader)
        avg_l1loss = avg_l1loss / len(trainloader)
        avg_gloss = avg_gloss / len(trainloader)
        print("Training L1:%4f" % (avg_l1loss))
        print("Training MSE:'{}'".format(train_mse))
        train_losses = [avg_l1loss, train_mse, avg_gloss]

        lrate = get_lr(optimizer)
        write_log_file(log_file_name, train_losses, epoch + 1, lrate,
                       'Train')

        model.eval()
        val_loss = 0.0
        val_mse = 0.0
        val_bg = 0.0
        val_fg = 0.0
        val_gloss = 0.0
        val_dloss = 0.0
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = Variable(images_val.cuda())
                labels_val = Variable(labels_val.cuda())

                outputs = model(images_val)
                pred_val = htan(outputs)
                g_loss = gloss(pred_val, labels_val).cpu()
                pred_val = pred_val.cpu()
                labels_val = labels_val.cpu()
                loss = loss_fn(pred_val, labels_val)
                val_loss += float(loss)
                val_mse += float(MSE(pred_val, labels_val))
                val_gloss += float(g_loss)

        val_loss = val_loss / len(valloader)
        val_mse = val_mse / len(valloader)
        val_gloss = val_gloss / len(valloader)
        print("val loss at epoch {}:: {}".format(epoch + 1, val_loss))
        print("val MSE: {}".format(val_mse))

        if args.tboard:
            show_wc_tnsboard(epoch + 1, writer, images_val, labels_val,
                             pred_val, 8, 'Val Inputs', 'Val WCs',
                             'Val Pred. WCs')
            writer.add_scalars('L1', {
                'L1_Loss/train': avg_l1loss,
                'L1_Loss/val': val_loss
            }, epoch + 1)
            writer.add_scalars('GLoss', {
                'Grad Loss/train': avg_gloss,
                'Grad Loss/val': val_gloss
            }, epoch + 1)

        val_losses = [val_loss, val_mse, val_gloss]
        write_log_file(log_file_name, val_losses, epoch + 1, lrate, 'Val')

        # reduce learning rate
        sched.step(val_mse)

        if val_mse < best_val_mse:
            best_val_mse = val_mse
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state,
                os.path.join(args.logdir,
                             "{}_{}_{}_{}_{}_best_model.pkl".format(
                                 args.arch, epoch + 1, val_mse, train_mse,
                                 experiment_name)))

        if (epoch + 1) % 10 == 0:
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state,
                os.path.join(args.logdir,
                             "{}_{}_{}_{}_{}_model.pkl".format(
                                 args.arch, epoch + 1, val_mse, train_mse,
                                 experiment_name)))
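
show_wc_tnsboard (and the analogous show_unwarp_tnsboard in Example #1) is not shown; a hypothetical reconstruction based purely on how it is called, using torchvision's grid helper (the real function may normalize or annotate differently):

from torchvision.utils import make_grid

def show_wc_tnsboard(step, writer, images, labels, pred, n,
                     inp_tag, gt_tag, pred_tag):
    # Log the first n inputs, ground-truth maps, and predictions as grids.
    writer.add_image(inp_tag, make_grid(images[:n], normalize=True), step)
    writer.add_image(gt_tag, make_grid(labels[:n], normalize=True), step)
    writer.add_image(pred_tag, make_grid(pred[:n], normalize=True), step)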
Example #8
def main_worker(index, opt):
    random.seed(opt.manual_seed)
    np.random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)

    if index >= 0 and opt.device.type == 'cuda':
        opt.device = torch.device(f'cuda:{index}')

    if opt.distributed:
        opt.dist_rank = opt.dist_rank * opt.ngpus_per_node + index
        dist.init_process_group(backend='nccl',
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.dist_rank)
        opt.batch_size = int(opt.batch_size / opt.ngpus_per_node)
        opt.n_threads = int(
            (opt.n_threads + opt.ngpus_per_node - 1) / opt.ngpus_per_node)
    opt.is_master_node = not opt.distributed or opt.dist_rank == 0
    model = generate_model(opt)
    if opt.batchnorm_sync:
        assert opt.distributed, 'SyncBatchNorm only supports DistributedDataParallel.'
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opt.pretrain_path:
        model = load_pretrained_model(model, opt.pretrain_path, opt.model,
                                      opt.n_finetune_classes)
    if opt.resume_path is not None:
        model = resume_model(opt.resume_path, opt.arch, model)
        print('resume model from ', opt.resume_path)

    print('model after resume:', model)

    # save model to current running id
    # mlflow.pytorch.log_model(model, "action_model")
    # model_path = mlflow.get_artifact_uri("action_model")
    # print('mlflow action model path: ', model_path)
    # model = mlflow.pytorch.load_model(model_path)
    if opt.ml_tag_name != '' and opt.ml_tag_value != '':
        # mlflow.set_tag("test_tag", 'inference_test')
        mlflow.set_tag(opt.ml_tag_name, opt.ml_tag_value)

    # load from previous published model version
    if opt.ml_model_name != '' and opt.ml_model_version != '':
        # model_name = 'action_model'
        # model_version = '1'
        model_uri = "models:/{}/{}".format(opt.ml_model_name,
                                           opt.ml_model_version)
        model = mlflow.pytorch.load_model(model_uri)

    model = make_data_parallel(model, opt.distributed, opt.device)

    if opt.pretrain_path:
        parameters = get_fine_tuning_parameters(model, opt.ft_begin_module)
    else:
        parameters = model.parameters()

    if opt.is_master_node:
        print(model)

    criterion = CrossEntropyLoss().to(opt.device)

    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)

    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None

    prev_val_loss = None
    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, tb_writer, opt.distributed)

            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)
                if opt.ml_model_name != '':
                    mlflow.pytorch.log_model(model, opt.ml_model_name)

        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, tb_writer,
                                      opt.distributed)
            mlflow.log_metric("loss", prev_val_loss)

        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)

    if opt.inference:
        inference_loader, inference_class_names = get_inference_utils(opt)
        inference_result_path = opt.result_path / '{}.json'.format(
            opt.inference_subset)

        inference.inference(inference_loader, model, inference_result_path,
                            inference_class_names, opt.inference_no_average,
                            opt.output_topk)
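
The models:/<name>/<version> URI above resolves through the MLflow Model Registry; a minimal sketch of how a version gets there and is loaded back (the placeholder model and names are illustrative):

import mlflow
import mlflow.pytorch
import torch

model = torch.nn.Linear(4, 2)  # placeholder model
with mlflow.start_run():
    # Logging with registered_model_name creates a new registry version.
    mlflow.pytorch.log_model(model, "action_model",
                             registered_model_name="action_model")

# Any later run can then pin a specific version:
restored = mlflow.pytorch.load_model("models:/action_model/1")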
Example #9
def train_model(train_df,
                train_images,
                test_df,
                test_images,
                base_model,
                criterion,
                log,
                device,
                exp_dir,
                fold=0,
                num_epoch=1,
                mask_epoch=1):

    ds = val_split(train_df, train_images, fold=fold)
    learn_start = time()

    log.info('classification learning start')
    log.info("-" * 20)
    model = base_model.to(device)
    # log.info(model)
    log.info(f'parameters {count_parameter(model)}')
    best_model_wts = copy.deepcopy(model.state_dict())
    best_recall = 0

    # Observe that all parameters are being optimized
    log.info('Optimizer: Adam')
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=conf.init_lr)  #, weight_decay=1e-5)

    log.info(f"Scheduler: CosineLR, period={conf.period}")
    train_ds, val_ds, train_images, val_images = ds['train'], ds['val'], ds[
        'train_images'], ds['val_images']

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'max',
        patience=20,
        threshold=0.001,
        threshold_mode="abs",
    )

    for epoch in range(42, num_epoch):  # NOTE: resumes from a hard-coded epoch 42
        try:
            start = time()

            _, train_res = train(
                model,
                optimizer,  # scheduler, 
                train_ds,
                train_images,
                train_transform,
                device,
                criterion,
                epoch=epoch)

            clf_loss = train_res['loss']
            val_preds, val_res = validate(model, val_ds, val_images,
                                          valid_transform, device, criterion)
            val_clf = val_res['loss']
            val_recall = val_res['recall']

            calc_time = time() - start
            accum_time = time() - learn_start
            lr = get_lr(optimizer)

            log_msg = f"{epoch}\t{calc_time:.2f}\t{accum_time:.1f}\t{lr:.4f}\t"
            log_msg += f"{clf_loss:.4f}\t"

            train_recall = train_res['recall']
            log_msg += f"{train_recall:.4f}\t"
            log_msg += f"{val_clf:.4f}\t{val_recall:.4f}\t"
            log.info(log_msg)
            scheduler.step(val_recall)

            if val_recall > best_recall:
                best_model_wts = copy.deepcopy(model.state_dict())
                best_recall = val_recall
                best_val_preds = val_preds
                torch.save(model.state_dict(), exp_dir / f'model_{fold}.pkl')
                np.save(exp_dir / f'val_preds_{fold}.npy', val_preds)

        except KeyboardInterrupt:
            break

    log.info("-" * 20)
    log.info('Best val Recall: {:4f}'.format(best_recall))

    # load best model weights
    model.load_state_dict(best_model_wts)
    # test_preds = predict(model, test_df, test_images, valid_transform,
    #                      device)

    return model, best_val_preds  # , test_preds
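
count_parameter is another helper not included in the snippet; a one-liner consistent with its use in the log line above:

def count_parameter(model):
    # Total number of trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)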
Example #10
        if arg.CUDA:
            points, target = points.cuda(), target.cuda()

        # No partial last batches, in order to reduce noise in gradient.
        if len(target) != arg.batch_size:
            break

        # Forward and backward pass
        prediction = model(points)
        if not arg.classification:
            prediction = prediction.view(-1)
        loss = train_loss_func(prediction, target)
        avg_train_score += loss.item()  # .item() detaches, so the graph isn't retained
        loss.backward()
        print('E: %02d - %02d/%02d - LR: %.6f - Loss: %.5f' %
              (epoch + 1, i + 1, num_batch, get_lr(optimizer)[0], loss),
              flush=True,
              end='\r')

        # Stepping
        optimizer.step()
        if arg.optimizer == 'SGD_cos':
            scheduler.step()

    # This section runs at the end of each batch
    test_score, x1, y1 = evaluateModel(model,
                                       test_loss_func,
                                       testloader,
                                       arg.dual,
                                       arg.CUDA,
                                       classification=arg.classification)
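
Note that this fragment indexes get_lr(optimizer)[0], so its get_lr evidently returns one rate per parameter group, unlike the scalar variant sketched after Example #1:

def get_lr(optimizer):
    # One learning rate per parameter group.
    return [group['lr'] for group in optimizer.param_groups]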
Example #11
def train(epoch, loss_fn):
    global tr_global_step, best_loss, best_iou, best_dice, best_acc, start_epoch
    writer.add_scalar('train/learning_rate', utils.get_lr(optimizer), epoch)
    model.train()
    torch.set_grad_enabled(True)
    optimizer.zero_grad()

    running_loss, running_iou, running_dice, running_acc = 0.0, 0.0, 0.0, 0.0
    it, total = 0, 0
    #pbar_disable = False if epoch == start_epoch else None
    pbar = tqdm(train_dataloader, unit="images",
                unit_scale=train_dataloader.batch_size,
                desc='Train: epoch {}'.format(epoch))
    for batch in pbar:
        inputs, targets = batch['input'], batch['target']
        #print("input shape: {}, output shape: {}".format(inputs.shape, targets.shape))
        inputs = inputs.float().cuda()
        targets = targets.cuda()

        # forward
        logits = model(inputs)
        logits = logits.squeeze(1)
        targets = targets.squeeze(1)

        probs = torch.sigmoid(logits)
        loss = loss_fn(logits, targets)
        # accumulate gradients
        loss = loss / args.grad_accumulation    
        loss.backward()
        if (it + 1) % args.grad_accumulation == 0:
            optimizer.step()
            optimizer.zero_grad()

        # statistics
        it += 1
        tr_global_step += 1
        loss = loss.item()
        running_loss += (loss * targets.size(0))
        total += targets.size(0)

        # writer.add_scalar('train/loss', loss, global_step)
        inputs_numpy = inputs.cpu().numpy()
        targets_numpy = targets.cpu().numpy()
        probs_numpy = probs.cpu().detach().numpy()
        predictions_numpy = probs_numpy > 0.5  # predictions.cpu().numpy()

        running_iou += iou_score(targets_numpy, predictions_numpy).sum()
        running_dice += dice_score(targets_numpy, predictions_numpy, noise_th=noise_th).sum()
        running_acc += accuracy_score(targets_numpy, predictions_numpy).sum()
        
        # update the progress bar
        pbar.set_postfix({
            'loss': "{:.05f}".format(running_loss / total),
            'IoU': "{:.03f}".format(running_iou / total),
            'Dice': "{:.03f}".format(running_dice / total),
            'Acc': "{:.03f}".format(running_acc / total)
        })

    epoch_loss = running_loss / total
    epoch_iou = running_iou / total
    epoch_dice = running_dice / total
    epoch_acc = running_acc / total
    writer.add_scalar('train/loss', epoch_loss, epoch)
    writer.add_scalar('train/iou', epoch_iou, epoch)
    writer.add_scalar('train/dice', epoch_dice, epoch)
    writer.add_scalar('train/accuracy', epoch_acc, epoch)

    return epoch_loss, epoch_iou, epoch_dice, epoch_acc
Example #12
def test(epoch, round, dataset, shifted_dataset=None, test=False):
    # test: when True, for the final test/validation pass; when False, for
    # selecting the best model during training.
    global best_validation_loss, best_model_state, best_acc

    net.eval()
    test_loss = 0
    accuracy_loss = 0
    distribution_mismatch_loss = 0
    correct = 0
    total = 0
    acc_per_class = PredictionAccPerClass()

    testloader = torch.utils.data.DataLoader(dataset, batch_size=200, shuffle=False, num_workers=2)
    regularizer = ProjectionCriterion()
    if shifted_dataset is not None:     # testing regularizd model
        testloader_shifted = torch.utils.data.DataLoader(shifted_dataset, batch_size=200, shuffle=False, num_workers=2)
        with torch.no_grad():
            for batch_idx, (test_data, shifted_test_data) in enumerate(zip(testloader, testloader_shifted)):
                test_input = test_data[0]
                targets = test_data[1]
                shifted_test_input = shifted_test_data[0]
                test_input, targets, shifted_test_input = test_input.to(device), targets.to(device), shifted_test_input.to(device)

                test_outputs = net(test_input)
                shifted_test_outputs = net(shifted_test_input)

                # another criterion
                if args.std_projection:
                    regularizer_loss = regularizer.compute_projection_criterion_withstd(test_outputs, shifted_test_outputs,
                                                                                batch_idx)
                # elif args.running_average:
                #     regularizer_loss = regularizer.compute_projection_criterion(test_outputs, shifted_test_outputs,
                #                                                                 batch_idx)
                else:
                    regularizer_loss = regularizer.compute_projection_criterion_simple(test_outputs,
                                                                                       shifted_test_outputs, batch_idx)

                loss = criterion(test_outputs, targets) + args.projection_weight * regularizer_loss

                test_loss += loss.item()
                accuracy_loss += criterion(test_outputs, targets).item()
                distribution_mismatch_loss += args.projection_weight * regularizer_loss
                _, predicted = test_outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                acc_per_class.update(predicted, targets)

            print("Accuracy + regularization loss %.3f , acc %.3f%% (%d/%d)"
                  % (test_loss/(batch_idx + 1), 100.*correct/total, correct, total))
            print("Accuracy loss %.3f"
                  % (accuracy_loss / (batch_idx + 1)))
            print("Distribution mismatch loss %.3f"
                  % (distribution_mismatch_loss / (batch_idx + 1)))
    else:                       # testing non-regularized model
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                acc_per_class.update(predicted, targets)
            print("Accuracy loss %.3f , acc %.3f%% (%d/%d)"
                  % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    acc_per_class.output_class_prediction()
    if not test:
        if test_loss < best_validation_loss:
            best_model_state = net.state_dict()
            print('Saving in round %s' % round)
            state = {
                'net': best_model_state,
                'validation_loss': test_loss,
                'epoch': epoch,
                'round': round,
                'learning_rate': get_lr(epoch, args.lr),
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt_concat_%s.pth' % root_name)
            best_validation_loss = test_loss
            return False
        else:
            print("Loss not reduced, decrease rate count: %s" % decrease_rate_count)
            return True
Example #13
        trainloader_shifted = torch.utils.data.DataLoader(training_set_shifted,
                                                          batch_size=int(50000/len(trainloader)),
                                                          shuffle=True, num_workers=2)

    # Model
    print('==> Building model..')
    net = net_list[aug]
    net = net.to(device)


    end_epoch = np.max([args.epoch, start_epoch+1])
    decrease_rate_count = 0
    stop_count = 0

    for epoch in range(start_epoch, end_epoch):
        optimizer = optim.SGD(net.parameters(), lr=get_lr(epoch, args.lr), momentum=args.momentum, weight_decay=0.001) # 0.0005, momentum 0.9, lr 0.001
        if reg_list[aug]:
            train(epoch, zip(trainloader, trainloader_shifted), reg=True)
            no_improve = test(epoch, aug, validation_set, validation_set_shifted, test=False)
        else:
            train(epoch, trainloader, reg=False)
            no_improve = test(epoch, aug, validation_set, test=False)  # test False means model might be saved

        if no_improve:
            decrease_rate_count += 1
        else:
            decrease_rate_count = 0
        if decrease_rate_count >= 32:
            break

    # Question: where to extract confident samples?
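
In the two fragments above, get_lr(epoch, args.lr) takes an epoch instead of an optimizer, i.e. a hand-rolled schedule that is re-applied by rebuilding the SGD optimizer each epoch. A plausible step-decay sketch (the decay interval and factor are assumptions):

def get_lr(epoch, base_lr):
    # Assumed step decay: halve the base rate every 30 epochs.
    return base_lr * (0.5 ** (epoch // 30))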
Example #14
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        total_loss = 0
        total_cls_loss = 0
        total_reg_loss = 0

        self.model.train()
        self.train_metrics.reset()
        for batch_idx, data in enumerate(self.data_loader):
            # image, target = Variable(data['data'].float().cuda()), Variable(data['bbox'].float().cuda())
            # print(type(data['data']))
            # print(type(data['bbox'][0]))
            image = Variable(data['data'].float().cuda())
            with torch.no_grad():
                target = [Variable(i.float().cuda()) for i in data['bbox']]

            # Linear Learning Rate Warm-up
            full_batch_idx = ((epoch - 1) * len(self.data_loader) + batch_idx)
            if epoch - 1 < self.warm_up:
                for params in self.optimizer.param_groups:
                    params['lr'] = self.init_lr / (
                        self.warm_up * len(self.data_loader)) * full_batch_idx
            lr = get_lr(self.optimizer)

            # -------- TRAINING LOOP --------
            self.optimizer.zero_grad()
            output = self.model(image)
            reg_loss, cls_loss = self.criterion(output, target)
            loss = reg_loss + cls_loss
            loss.backward()
            self.optimizer.step()
            # -------------------------------

            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_reg_loss += reg_loss.item()

            # self.train_metrics.update('loss', loss.item())
            # for met in self.metric_ftns:
            #     self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break

            # log = self.train_metrics.result()
            wandb_log = {
                'loss': loss.item(),
                'cls_loss': cls_loss.item(),
                'reg_loss': reg_loss.item()
            }

            # Add log to WandB
            if not self.config['debug']:
                wandb.log(wandb_log)

        # log = self.train_metrics.result()
        log = {
            'loss': total_loss / self.len_epoch,
            'cls_loss': total_cls_loss / self.len_epoch,
            'reg_loss': total_reg_loss / self.len_epoch
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        log.update({'lr': lr})

        return log
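
The warm-up block in _train_epoch ramps the rate linearly from 0 to self.init_lr over the first warm_up epochs. The same schedule as a standalone function (the clamp at init_lr after warm-up is an assumption; the trainer above simply stops overriding the rate):

def warmup_lr(init_lr, full_batch_idx, warm_up_epochs, batches_per_epoch):
    # Linear ramp from 0 to init_lr over the first warm_up_epochs epochs.
    total_warmup_steps = warm_up_epochs * batches_per_epoch
    return init_lr * min(full_batch_idx / total_warmup_steps, 1.0)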
Example #15
def main():
    # print("Starting DFC2021 baseline training script at %s" % (str(datetime.datetime.now())))
    #-------------------
    # Setup
    #-------------------
    assert os.path.exists(args.train_fn)
    assert os.path.exists(args.valid_fn)

    now_time = datetime.datetime.now()
    time_str = datetime.datetime.strftime(now_time, '%m-%d_%H-%M-%S')
    # output path
    # output_dir = Path(args.output_dir).parent / time_str / Path(args.output_dir).stem
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    logger = utils.init_logger(output_dir / 'info.log')

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    n_gpu = torch.cuda.device_count()
    device = torch.device('cuda:0' if n_gpu > 0 else 'cpu')
    device_ids = list(range(n_gpu))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    #-------------------
    # Load input data
    #-------------------

    train_dataframe = pd.read_csv(args.train_fn)
    train_image_fns = train_dataframe["image_fn"].values
    train_label_fns = train_dataframe["label_fn"].values
    train_groups = train_dataframe["group"].values
    train_dataset = StreamingGeospatialDataset(
        imagery_fns=train_image_fns, label_fns=train_label_fns, groups=train_groups, chip_size=CHIP_SIZE,
        num_chips_per_tile=NUM_CHIPS_PER_TILE, transform=transform, nodata_check=nodata_check
    )

    valid_dataframe = pd.read_csv(args.valid_fn)
    valid_image_fns = valid_dataframe["image_fn"].values
    valid_label_fns = valid_dataframe["label_fn"].values
    valid_groups = valid_dataframe["group"].values
    valid_dataset = StreamingValidationDataset(
        imagery_fns=valid_image_fns, label_fns=valid_label_fns, groups=valid_groups, chip_size=CHIP_SIZE,
        stride=CHIP_SIZE, transform=transform, nodata_check=nodata_check
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    num_training_images_per_epoch = int(len(train_image_fns) * NUM_CHIPS_PER_TILE)
    # print("We will be training with %d batches per epoch" % (num_training_batches_per_epoch))

    #-------------------
    # Setup training
    #-------------------
    # if args.model == "unet":
    #     model = models.get_unet()
    # elif args.model == "fcn":
    #     model = models.get_fcn()
    # else:
    #     raise ValueError("Invalid model")

    model = models.isCNN(args.backbone)

    weights_init(model, seed=args.seed)

    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.AdamW(trainable_params, lr=INIT_LR, amsgrad=True, weight_decay=5e-4)
    lr_criterion = nn.CrossEntropyLoss(ignore_index=0)  # todo
    hr_criterion = hr_loss
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.5, patience=3, min_lr=1e-7)
    logger.info("Trainable parameters: {}".format(utils.count_parameters(model)))

    #-------------------
    # Model training
    #-------------------
    train_loss_total_epochs, valid_loss_total_epochs, epoch_lr = [], [], []
    best_loss = 1e50
    num_times_lr_dropped = 0
    # model_checkpoints = []
    # temp_model_fn = os.path.join(output_dir, "most_recent_model.pt")

    for epoch in range(args.num_epochs):
        lr = utils.get_lr(optimizer)

        train_loss_epoch, valid_loss_epoch = utils.fit(
            model,
            device,
            train_dataloader,
            valid_dataloader,
            num_training_images_per_epoch,
            optimizer,
            lr_criterion,
            hr_criterion,
            epoch,
            logger)

        scheduler.step(valid_loss_epoch)

        if epoch % config.SAVE_PERIOD == 0 and epoch != 0:
            temp_model_fn = output_dir / 'checkpoint-epoch{}.pth'.format(epoch+1)
            torch.save(model.state_dict(), temp_model_fn)

        if valid_loss_epoch < best_loss:
            logger.info("Saving model_best.pth...")
            temp_model_fn = output_dir / 'model_best.pth'
            torch.save(model.state_dict(), temp_model_fn)
            best_loss = valid_loss_epoch

        if utils.get_lr(optimizer) < lr:
            num_times_lr_dropped += 1
            print("")
            print("Learning rate dropped")
            print("")

        train_loss_total_epochs.append(train_loss_epoch)
        valid_loss_total_epochs.append(valid_loss_epoch)
        epoch_lr.append(lr)
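The plateau scheduler above halves the rate once the validation loss has stopped improving for `patience` epochs. A self-contained toy run of that pattern, assuming nothing beyond stock PyTorch:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, amsgrad=True, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, "min", factor=0.5, patience=3, min_lr=1e-7)

for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]):
    scheduler.step(val_loss)  # stale since epoch 1; lr halves once patience is exhausted
    print(epoch, optimizer.param_groups[0]["lr"])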
Exemple #16
0
def main_worker(index, opt):
    random.seed(opt.manual_seed)
    np.random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)

    if index >= 0 and opt.device.type == 'cuda':
        opt.device = torch.device(f'cuda:{index}')

    if opt.distributed:
        opt.dist_rank = opt.dist_rank * opt.ngpus_per_node + index
        dist.init_process_group(backend='nccl',
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.dist_rank)
        opt.batch_size = int(opt.batch_size / opt.ngpus_per_node)
        # opt.n_threads = int(
        #     (opt.n_threads + opt.ngpus_per_node - 1) / opt.ngpus_per_node)
    opt.is_master_node = not opt.distributed or opt.dist_rank == 0


    model = generate_model(opt)
    if opt.batchnorm_sync:
        assert opt.distributed, 'SyncBatchNorm only supports DistributedDataParallel.'
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if opt.distributed:
        model = make_data_parallel(model, opt.device)
    else:
        model.to(opt.device)
        # model = nn.DataParallel(model).cuda()

    print('Total params: %.2fM' % (sum(p.numel()
                                       for p in model.parameters()) / 1000000.0))
    if opt.is_master_node:
        print(model)
    parameters = model.parameters()
    criterion = CrossEntropyLoss().to(opt.device)

    (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)

    val_loader, val_logger = get_val_utils(opt)

    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None

    print('Data loading complete')
    prev_val_loss = None
    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
                # train_sampler2.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, opt.is_master_node, tb_writer, opt.distributed)

            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, model, optimizer,
                                scheduler)

        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, opt.is_master_node, tb_writer,
                                      opt.distributed)

        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)
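`main_worker` is written to be spawned once per GPU; `index >= 0` distinguishes spawned workers from a direct single-process call. A hedged sketch of a matching launcher, assuming `opt` carries the fields used above:

import torch
import torch.multiprocessing as mp

def launch(opt):
    if opt.distributed:
        # One worker process per local GPU; mp.spawn passes the process
        # index as the first argument of main_worker.
        opt.ngpus_per_node = torch.cuda.device_count()
        opt.world_size = opt.ngpus_per_node * opt.world_size
        mp.spawn(main_worker, nprocs=opt.ngpus_per_node, args=(opt,))
    else:
        main_worker(-1, opt)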
Exemple #17
0
def main():
    print("Starting DFC2021 baseline training script at %s" % (str(datetime.datetime.now())))


    #-------------------
    # Setup
    #-------------------
    assert os.path.exists(args.input_fn)

    if os.path.isfile(args.output_dir):
        print("A file was passed as `--output_dir`, please pass a directory!")
        return

    if os.path.exists(args.output_dir) and len(os.listdir(args.output_dir)):
        if args.overwrite:
            print("WARNING! The output directory, %s, already exists, we might overwrite data in it!" % (args.output_dir))
        else:
            print("The output directory, %s, already exists and isn't empty. We don't want to overwrite and existing results, exiting..." % (args.output_dir))
            return
    else:
        print("The output directory doesn't exist or is empty.")
        os.makedirs(args.output_dir, exist_ok=True)

    if torch.cuda.is_available():
        device = torch.device("cuda:%d" % args.gpu)
    else:
        print("WARNING! Torch is reporting that CUDA isn't available, exiting...")
        return

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


    #-------------------
    # Load input data
    #-------------------
    input_dataframe = pd.read_csv(args.input_fn)
    image_fns = input_dataframe["image_fn"].values
    label_fns = input_dataframe["label_fn"].values
    groups = input_dataframe["group"].values

    dataset = StreamingGeospatialDataset(
        imagery_fns=image_fns, label_fns=label_fns, groups=groups, chip_size=CHIP_SIZE, num_chips_per_tile=NUM_CHIPS_PER_TILE, windowed_sampling=False, verbose=False,
        image_transform=image_transforms, label_transform=label_transforms, nodata_check=nodata_check
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    num_training_batches_per_epoch = int(len(image_fns) * NUM_CHIPS_PER_TILE / args.batch_size)
    print("We will be training with %d batches per epoch" % (num_training_batches_per_epoch))


    #-------------------
    # Setup training
    #-------------------
    if args.model == "unet":
        model = models.get_unet()
    elif args.model == "fcn":
        model = models.get_fcn()
    else:
        raise ValueError("Invalid model")

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, amsgrad=True)
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")

    print("Model has %d parameters" % (utils.count_parameters(model)))


    #-------------------
    # Model training
    #-------------------
    training_task_losses = []
    num_times_lr_dropped = 0 
    model_checkpoints = []
    temp_model_fn = os.path.join(args.output_dir, "most_recent_model.pt")

    for epoch in range(args.num_epochs):
        lr = utils.get_lr(optimizer)

        training_losses = utils.fit(
            model,
            device,
            dataloader,
            num_training_batches_per_epoch,
            optimizer,
            criterion,
            epoch,
        )
        scheduler.step(training_losses[0])

        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        if args.save_most_recent:
            torch.save(model.state_dict(), temp_model_fn)

        if utils.get_lr(optimizer) < lr:
            num_times_lr_dropped += 1
            print("")
            print("Learning rate dropped")
            print("")
            
        training_task_losses.append(training_losses[0])
            
        if num_times_lr_dropped == 4:
            break


    #-------------------
    # Save everything
    #-------------------
    save_obj = {
        'args': args,
        'training_task_losses': training_task_losses,
        "checkpoints": model_checkpoints
    }

    save_obj_fn = "results.pt"
    with open(os.path.join(args.output_dir, save_obj_fn), 'wb') as f:
        torch.save(save_obj, f)
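The stopping criterion above detects scheduler decays by comparing `get_lr` before and after `scheduler.step`, and gives up after four drops. The comparison factored into a helper (a sketch; `lr_before` is the value read at the top of the epoch):

def lr_dropped(optimizer, lr_before):
    # ReduceLROnPlateau mutates param_groups in place, so a smaller value
    # after scheduler.step() means a decay happened this epoch.
    return optimizer.param_groups[0]["lr"] < lr_before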
Exemple #18
0
def train_epoch(epoch,
                data_loader,
                model,
                criterion,
                optimizer,
                device,
                epoch_logger,
                batch_logger,
                scheduler,
                lr_scheduler,
                tb_writer=None,
                distributed=False):
    print('train at epoch {}'.format(epoch))

    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end_time = time.time()
    for i, (inputs, targets) in enumerate(data_loader):
        data_time.update(time.time() - end_time)

        targets = targets.to(device, non_blocking=True)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        acc = calculate_accuracy(outputs, targets)

        losses.update(loss.item(), inputs.size(0))
        accuracies.update(acc, inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if lr_scheduler == 'cosineannealingwarmrestart':
            scheduler.step()

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        with torch.no_grad():
            current_lr = get_lr(optimizer)

        if batch_logger is not None:
            batch_logger.log({
                'epoch': epoch,
                'batch': i + 1,
                'iter': (epoch - 1) * len(data_loader) + (i + 1),
                'loss': losses.val,
                'acc': accuracies.val,
                'lr': current_lr
            })

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc {acc.val:.3f} ({acc.avg:.3f})\t'
              'Lr {lr:.6f}'.format(epoch,
                                   i + 1,
                                   len(data_loader),
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses,
                                   acc=accuracies,
                                   lr=current_lr).expandtabs(tabsize=4))

    if distributed:
        loss_sum = torch.tensor([losses.sum],
                                dtype=torch.float32,
                                device=device)
        loss_count = torch.tensor([losses.count],
                                  dtype=torch.float32,
                                  device=device)
        acc_sum = torch.tensor([accuracies.sum],
                               dtype=torch.float32,
                               device=device)
        acc_count = torch.tensor([accuracies.count],
                                 dtype=torch.float32,
                                 device=device)

        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(loss_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(acc_count, op=dist.ReduceOp.SUM)

        losses.avg = loss_sum.item() / loss_count.item()
        accuracies.avg = acc_sum.item() / acc_count.item()

    if epoch_logger is not None:
        epoch_logger.log({
            'epoch': epoch,
            'loss': losses.avg,
            'acc': accuracies.avg,
            'lr': current_lr
        })

    if tb_writer is not None:
        tb_writer.add_scalar('train/loss', losses.avg, epoch)
        tb_writer.add_scalar('train/acc', accuracies.avg, epoch)
        tb_writer.add_scalar('train/lr', current_lr, epoch)
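The distributed branch above all-reduces the meters' sums and counts separately and divides afterwards, which yields the exact global mean rather than a mean of per-rank means. The same idea isolated into a helper, assuming the process group is already initialized:

import torch
import torch.distributed as dist

def distributed_mean(value_sum, value_count, device):
    # Reduce numerator and denominator together, then divide once.
    t = torch.tensor([value_sum, value_count], dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t[0] / t[1]).item()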
Exemple #19
0
def train(model, num_gpus, output_directory, epochs, learning_rate, lr_decay_step, lr_decay_gamma,
          sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard):
    # local eval and synth functions
    def evaluate():
        # eval loop
        model.eval()
        epoch_eval_loss = 0
        for i, batch in enumerate(test_loader):
            with torch.no_grad():
                mel, audio = batch
                mel = torch.autograd.Variable(mel.cuda())
                audio = torch.autograd.Variable(audio.cuda())
                outputs = model(audio, mel)

                loss = criterion(outputs)
                if num_gpus > 1:
                    reduced_loss = loss.mean().item()
                else:
                    reduced_loss = loss.item()
                epoch_eval_loss += reduced_loss

        epoch_eval_loss = epoch_eval_loss / len(test_loader)
        print("EVAL {}:\t{:.9f}".format(iteration, epoch_eval_loss))
        if with_tensorboard:
            logger.add_scalar('eval_loss', epoch_eval_loss, iteration)
            logger.flush()
        model.train()

    def synthesize(sigma):
        model.eval()
        # synthesize loop
        for i, batch in enumerate(synth_loader):
            if i == 0:
                with torch.no_grad():
                    mel, _, filename = batch
                    mel = torch.autograd.Variable(mel.cuda())
                    try:
                        audio = model.reverse(mel, sigma)
                    except AttributeError:
                        audio = model.module.reverse(mel, sigma)
                    except NotImplementedError:
                        print("reverse not implemented for this model. skipping synthesize!")
                        model.train()
                        return

                    audio = audio * MAX_WAV_VALUE
                audio = audio.squeeze()
                audio = audio.cpu().numpy()
                audio = audio.astype('int16')
                audio_path = os.path.join(
                    os.path.join(output_directory, "samples", waveflow_config["model_name"]),
                    "generate_{}.wav".format(iteration))
                write(audio_path, data_config["sampling_rate"], audio)

        model.train()

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    criterion = WaveFlowLossDataParallel(sigma)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=lr_decay_step, gamma=lr_decay_gamma)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if args.resume:
        model_directory = os.path.join(
            output_directory, waveflow_config["model_name"]
        )
        logging.info("--resume. Resuming the training from the last "
            "checkpoint found in {}.".format(model_directory))
        last_checkpoint = last_n_checkpoints(model_directory, 1)[0]
        model, optimizer, scheduler, iteration = \
            load_checkpoint(last_checkpoint, model, optimizer, scheduler)

    elif checkpoint_path != "":
        # Warm-start
        if args.warm_start and args.average_checkpoint == 0:
            print("INFO: --warm_start. optimizer and scheduler are initialized and strict=False for load_state_dict().")
            model, optimizer, scheduler, iteration = load_checkpoint_warm_start(
                    checkpoint_path, model, optimizer, scheduler)
        elif args.warm_start and args.average_checkpoint != 0:
            print("INFO: --average_checkpoint > 0. loading an averaged "
                  "weight of last {} checkpoints...".format(args.average_checkpoint))
            model, optimizer, scheduler, iteration = load_averaged_checkpoint_warm_start(
                checkpoint_path, model, optimizer, scheduler
            )
        else:
            model, optimizer, scheduler, iteration = \
                load_checkpoint(checkpoint_path, model, optimizer, scheduler)
        iteration += 1  # next iteration is iteration + 1

    if num_gpus > 1:
        print("num_gpus > 1. converting the model to DataParallel...")
        model = torch.nn.DataParallel(model)

    trainset = Mel2Samp("train", False, False, **data_config)
    train_loader = DataLoader(trainset, num_workers=4, shuffle=True,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)

    testset = Mel2Samp("test", False, False, **data_config)
    test_sampler = None
    test_loader = DataLoader(testset, num_workers=4, shuffle=False,
                             sampler=test_sampler,
                             batch_size=batch_size,
                             pin_memory=False,
                             drop_last=False)

    synthset = Mel2Samp("test", True, True, **data_config)
    synth_sampler = None
    synth_loader = DataLoader(synthset, num_workers=4, shuffle=False,
                              sampler=synth_sampler,
                              batch_size=1,
                              pin_memory=False,
                              drop_last=False)

    # Get shared output_directory ready
    if not os.path.isdir(os.path.join(output_directory, waveflow_config["model_name"])):
        os.makedirs(os.path.join(output_directory, waveflow_config["model_name"]), exist_ok=True)
        os.chmod(os.path.join(output_directory, waveflow_config["model_name"]), 0o775)
    print("output directory", os.path.join(output_directory, waveflow_config["model_name"]))
    if not os.path.isdir(os.path.join(output_directory, "samples")):
        os.makedirs(os.path.join(output_directory, "samples"), exist_ok=True)
        os.chmod(os.path.join(output_directory, "samples"), 0o775)
    os.makedirs(os.path.join(output_directory, "samples", waveflow_config["model_name"]), exist_ok=True)
    os.chmod(os.path.join(output_directory, "samples", waveflow_config["model_name"]), 0o775)

    if with_tensorboard:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, waveflow_config["model_name"], 'logs'))

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
            tic = time.time()

            model.zero_grad()

            mel, audio = batch
            mel = torch.autograd.Variable(mel.cuda())
            audio = torch.autograd.Variable(audio.cuda())
            outputs = model(audio, mel)

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = loss.mean().item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.mean().backward()

            if fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 5.)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            toc = time.time() - tic

            #print("{}:\t{:.9f}, {:.4f} seconds".format(iteration, reduced_loss, toc))
            if with_tensorboard:
                logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)
                logger.add_scalar('lr', get_lr(optimizer), i + len(train_loader) * epoch)
                logger.add_scalar('grad_norm', grad_norm, i + len(train_loader) * epoch)
                logger.flush()

            if (iteration % iters_per_checkpoint == 0):
                checkpoint_path = "{}/waveflow_{}".format(
                    os.path.join(output_directory, waveflow_config["model_name"]), iteration)
                save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                                checkpoint_path)

                if iteration != 0:
                    evaluate()
                    del mel, audio, outputs, loss
                    gc.collect()
                    synthesize(sigma)

            iteration += 1
            scheduler.step()

        evaluate()
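The fp16 path above depends on NVIDIA apex, which is deprecated upstream. The equivalent scale-then-clip step in current PyTorch, as a hedged sketch rather than a drop-in replacement for this script:

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_step(model, mel, audio, criterion, optimizer, max_norm=5.0):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(audio, mel)).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)
    scaler.update()
    return loss.item(), grad_norm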
Exemple #20
0
def train(model,
          optimizer,
          loss_function,
          n_epochs,
          training_loader,
          validation_loader,
          model_filename,
          metric_to_monitor="val_loss",
          metric=None,
          early_stopping_patience=None,
          learning_rate_decay_patience=None,
          save_best=False,
          n_gpus=1,
          verbose=True,
          regularized=False,
          decay_factor=0.1,
          min_lr=0.,
          learning_rate_decay_step_size=None,
          pretrain=None):
    training_log = list()

    start_epoch = 0
    training_log_header = [
        "epoch", "loss", "lr", "val_loss", 'metric_score', 'metric_val_score'
    ]

    if learning_rate_decay_patience:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=learning_rate_decay_patience,
            verbose=verbose,
            factor=decay_factor,
            min_lr=min_lr)
    elif learning_rate_decay_step_size:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=optimizer,
            step_size=learning_rate_decay_step_size,
            gamma=decay_factor,
            last_epoch=-1)
        # Setting the last epoch to anything other than -1 requires the optimizer that was previously used.
        # Since I don't save the optimizer, I have to manually step the scheduler the number of epochs that have already
        # been completed. Stepping the scheduler before the optimizer raises a warning, so I have added the below
        # code to step the scheduler and catch the UserWarning that would normally be thrown.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for i in range(start_epoch):
                scheduler.step()
    else:
        scheduler = None

    for epoch in range(start_epoch, n_epochs):

        # Freeze the encoder blocks for the first 10 epochs when pretraining;
        # otherwise make sure every encoder block is trainable.
        freeze = pretrain and epoch < 10
        for encoder in (model.enc0, model.enc1, model.enc2, model.enc3,
                        model.enc4):
            for p in encoder.parameters():
                p.requires_grad = not freeze

        # early stopping
        if (training_log and early_stopping_patience
                and np.asarray(training_log)
                [:, training_log_header.index(metric_to_monitor)].argmin() <=
                len(training_log) - early_stopping_patience):
            print("Early stopping patience {} has been reached.".format(
                early_stopping_patience))
            break

        # train the model
        loss, metric_score = epoch_training(training_loader,
                                            model,
                                            loss_function,
                                            optimizer=optimizer,
                                            metric=metric,
                                            epoch=epoch,
                                            n_gpus=n_gpus,
                                            regularized=regularized)
        try:
            training_loader.dataset.on_epoch_end()
        except AttributeError:
            warnings.warn(
                "'on_epoch_end' method not implemented for the {} dataset."
                .format(type(training_loader.dataset)))

        # predict validation data
        if validation_loader:
            val_loss, metric_val_score = epoch_validatation(
                validation_loader,
                model,
                loss_function,
                metric=metric,
                n_gpus=n_gpus,
                regularized=regularized)
            metric_val_score = 1 - metric_val_score
        else:
            val_loss = None
            metric_val_score = None

        # update the training log
        training_log.append([
            epoch, loss,
            get_lr(optimizer), val_loss, metric_score, metric_val_score
        ])  # each epoch add results to training log
        pd.DataFrame(training_log, columns=training_log_header).set_index(
            "epoch").to_csv(model_filename + 'log.csv')

        min_epoch = np.asarray(training_log)[:, training_log_header.index(
            metric_to_monitor)].argmin()

        # check loss and decay
        if scheduler:
            if validation_loader and isinstance(
                    scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)  # case plateau on validation set
            elif isinstance(scheduler,
                            torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(loss)
            else:
                scheduler.step()

        # save model
        torch.save(model.state_dict(), model_filename)
        if save_best and min_epoch == len(training_log) - 1:
            best_filename = model_filename.replace(".h5", "_best.h5")
            forced_copy(model_filename, best_filename)
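The early-stopping test above asks whether the epoch of the best monitored value is already `patience` entries in the past. Factored out, using the same numpy-over-the-log approach:

import numpy as np

def should_stop(training_log, header, metric_to_monitor, patience):
    # The metric is minimized, so argmin gives the epoch of the best value;
    # stop once that epoch is at least `patience` rows behind the latest.
    column = np.asarray(training_log)[:, header.index(metric_to_monitor)]
    return column.argmin() <= len(training_log) - patience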
Exemple #21
0
def main():
    print("MODEL ID: {}".format(C.model_id))

    summary_writer = SummaryWriter(C.log_dpath)

    train_iter, val_iter, test_iter, vocab = build_loaders()

    model = build_model(vocab)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=C.lr,
                                 weight_decay=C.weight_decay,
                                 amsgrad=True)
    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     mode='min',
                                     factor=C.lr_decay_gamma,
                                     patience=C.lr_decay_patience,
                                     verbose=True)

    try:
        best_val_CIDEr = 0.
        best_epoch = None
        best_ckpt_fpath = None
        for e in range(1, C.epochs + 1):
            print("\n\n\nEpoch {:d}".format(e))

            ckpt_fpath = C.ckpt_fpath_tpl.format(e)
            """ Train """
            print("\n[TRAIN]")
            train_loss = train(e, model, optimizer, train_iter, vocab,
                               C.decoder.rnn_teacher_forcing_ratio,
                               C.reg_lambda, C.gradient_clip)
            log_train(summary_writer, e, train_loss, get_lr(optimizer))
            """ Validation """
            print("\n[VAL]")
            val_loss = evaluate(model, val_iter, vocab, C.reg_lambda)
            val_scores, _, _ = score(model,
                                     val_iter,
                                     vocab,
                                     beam_width=5,
                                     beam_alpha=0.)
            log_val(summary_writer, e, val_loss, val_scores)

            if e >= C.save_from and e % C.save_every == 0:
                print("Saving checkpoint at epoch={} to {}".format(
                    e, ckpt_fpath))
                save_checkpoint(e, model, ckpt_fpath, C)

            if e >= C.lr_decay_start_from:
                lr_scheduler.step(val_loss['total'])
            if val_scores['CIDEr'] > best_val_CIDEr:
                best_epoch = e
                best_val_CIDEr = val_scores['CIDEr']
                best_ckpt_fpath = ckpt_fpath
    except KeyboardInterrupt:
        if e >= C.save_from:
            print("Saving checkpoint at epoch={}".format(e))
            save_checkpoint(e, model, ckpt_fpath, C)
        else:
            print("Do not save checkpoint at epoch={}".format(e))
    finally:
        """ Test with Best Model """
        print("\n\n\n[BEST]")
        best_model = load_checkpoint(model, best_ckpt_fpath)
        best_scores, _, _ = score(best_model,
                                  test_iter,
                                  vocab,
                                  beam_width=5,
                                  beam_alpha=0.)
        print("scores: {}".format(best_scores))
        for metric in C.metrics:
            summary_writer.add_scalar("BEST SCORE/{}".format(metric),
                                      best_scores[metric], best_epoch)
        save_checkpoint(e, best_model, C.ckpt_fpath_tpl.format("best"), C)
Exemple #22
0
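    # (fragment: body of an epoch loop; model, the loaders, optimizer, loss_fun,
    #  train_mse/valid_mse, min_mse and the timer `start` are defined in omitted code)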
    model.train()
    # use train_epoch_scale/eval_epoch_scale for training scale equivariant models
    train_mse.append(train_epoch(train_loader, model, optimizer, loss_fun))
    model.eval()
    mse, _, _ = eval_epoch(valid_loader, model, loss_fun)
    valid_mse.append(mse)

    if valid_mse[-1] < min_mse:
        min_mse = valid_mse[-1]
        best_model = model
        torch.save(best_model, save_name + ".pth")
    end = time.time()

    # Early Stopping but train at least for 50 epochs
    if (len(train_mse) > 50
            and np.mean(valid_mse[-5:]) >= np.mean(valid_mse[-10:-5])):
        break
    print(i + 1, train_mse[-1], valid_mse[-1], round((end - start) / 60, 5),
          format(get_lr(optimizer), "5.2e"))

test_mse, preds, trues, loss_curve = test_epoch(test_loader, best_model,
                                                loss_fun)
torch.save(
    {
        "preds": preds,
        "trues": trues,
        "test_mse": test_mse,
        "loss_curve": loss_curve
    }, name + ".pt")
Exemple #23
0
    def get_current_lrs(self):
        lrs = {
            f'learning_rate/{key}': get_lr(self.optimizers[key])
            for key in self.optimizers
        }
        return lrs
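This getter assumes `self.optimizers` is a dict mapping names to optimizers; the result maps keys like `learning_rate/<name>` to the current rates, ready to pass straight to a scalar logger such as TensorBoard's `add_scalars` or `wandb.log`.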
Exemple #24
0
def main():
    # Instantiate parser
    parser = ArgumentParser()

    parser.add_argument("--network_config", required=True)

    # Parse
    args = parser.parse_args()

    # Parse config
    config = utils.parse_config(config_fname=args.network_config)

    # Temporary files
    tmp_dir = utils.create_temp_dir(config_fname=args.network_config)

    # Device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # TensorBoard
    summary_writer = SummaryWriter(tmp_dir)

    # Test batch (for visualization purposes)
    test_batch = analysis.get_test_batch(batch_size=config["batch_size"],
                                         size=config["size"],
                                         circle_radius=config["circle_radius"],
                                         device=device)

    # Build network
    net = Network(config["network"],
                  in_height=config["size"],
                  in_width=config["size"]).to(device)

    # Optimizer
    optimizer = optim.SGD(net.parameters(),
                          lr=config["learning_rate"],
                          momentum=config["momentum"],
                          nesterov=config["nesterov"])

    # Learning rate schedule
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=config["lr_gamma"])

    # Loss functions
    criterion_radius = nn.L1Loss()

    def _criterion_phi(alpha, beta):
        """
        PyTorch training criterion for angle loss (mean of great circle distances)
        Range of values: [0, 1]
        :param alpha: A PyTorch Tensor specifying the first angles
        :param beta: A PyTorch Tensor specifying the second angles
        :return: A scalar PyTorch Tensor (the mean great circle distance)
        """
        return great_circle_distance(alpha, beta).mean()

    for global_step, batch in enumerate(
            gen(batch_size=config["batch_size"],
                size=config["size"],
                circle_radius=config["circle_radius"],
                device=device)):

        optimizer.zero_grad()
        radius_pred, phi_pred = net(batch.image)
        loss_radius = criterion_radius(radius_pred, batch.radius)
        loss_phi = _criterion_phi(phi_pred, batch.phi)
        loss = (loss_radius + loss_phi) / 2.
        loss.backward()
        optimizer.step()

        # Adjust learning rate
        if global_step % config["lr_step"] == 0:
            scheduler.step()

        # TensorBoard
        with torch.set_grad_enabled(mode=False):

            # Copy onto CPU (detach first: these tensors still carry grad history)
            radius_pred = radius_pred.detach().cpu().numpy()
            phi_pred = phi_pred.detach().cpu().numpy()

            summary_writer.add_scalar(tag="lr",
                                      scalar_value=utils.get_lr(optimizer),
                                      global_step=global_step)
            summary_writer.add_scalar(tag="l1_radius",
                                      scalar_value=loss_radius,
                                      global_step=global_step)
            summary_writer.add_scalar(tag="l1_phi",
                                      scalar_value=loss_phi,
                                      global_step=global_step)
            summary_writer.add_scalar(tag="l1",
                                      scalar_value=loss,
                                      global_step=global_step)

            summary_writer.add_histogram(tag="histogram_radius_pred",
                                         values=radius_pred,
                                         global_step=global_step)
            summary_writer.add_histogram(tag="histogram_phi_pred",
                                         values=phi_pred,
                                         global_step=global_step)

            # Draw
            if config["save_images"]:

                # Eval mode
                net = net.eval()

                # Copy onto CPU
                radius_test_pred, phi_test_pred = net(test_batch.image)
                radius_test_pred = radius_test_pred.cpu().numpy()
                phi_test_pred = phi_test_pred.cpu().numpy()

                analysis.draw_test_batch(
                    image_batch=test_batch.image.cpu().numpy(),
                    radius_batch=radius_test_pred,
                    phi_batch=phi_test_pred,
                    tmp_dir=tmp_dir,
                    global_step=global_step,
                    size=config["size"],
                    circle_radius=config["circle_radius"])

                # Training mode
                net = net.train()
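Because `scheduler.step()` above only runs every `lr_step` iterations, the rate decays geometrically in coarse steps. The value after n scheduler steps, as a one-line sanity check (whether the step at `global_step == 0` counts depends on the loop above):

def lr_after_n_decays(base_lr, gamma, n):
    # ExponentialLR multiplies the current rate by gamma on every
    # scheduler.step(), so n steps give base_lr * gamma**n.
    return base_lr * gamma ** n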
Exemple #25
0
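                    # (fragment: inner batch loop of a combined train/val phase;
                    #  the meters, losses and counters are defined in omitted code above)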
                    # update meter
                    meter_loss_rmse.update(np.sqrt(loss_mse.item()), B)
                    meter_loss_kp.update(loss_kp.item(), B)
                    meter_loss_H.update(loss_H.item(), B)
                    meter_loss.update(loss.item(), B)

            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if i % args.log_per_iter == 0:
                log = '%s [%d/%d][%d/%d] LR: %.6f' % (
                    phase, epoch, args.n_epoch, i, data_n_batches[phase],
                    get_lr(optimizer))

                if args.stage == 'dy':
                    log += ', kp: %.6f (%.6f), H: %.6f (%.6f)' % (loss_kp.item(
                    ), meter_loss_kp.avg, loss_H.item(), meter_loss_H.avg)

                    log += ' [%d' % num_edge_per_type[0]
                    for tt in range(1, args.edge_type_num):
                        log += ', %d' % num_edge_per_type[tt]
                    log += ']'

                    log += ', rmse: %.6f (%.6f)' % (np.sqrt(
                        loss_mse.item()), meter_loss_rmse.avg)

                    if args.env in ['Ball']:
                        log += ', acc: %.4f (%.4f)' % (permu_edge_acc,
Exemple #26
0
def train():
    # input params
    set_seed(GLOBAL_SEED)
    config = getConfig()
    data_config = getDatasetConfig(config.dataset)
    sw_log = 'logs/%s' % config.dataset
    sw = SummaryWriter(log_dir=sw_log)
    best_prec1 = 0.
    rate = 0.875

    # define train_dataset and loader
    transform_train = transforms.Compose([
        transforms.Resize(
            (int(config.input_size // rate), int(config.input_size // rate))),
        transforms.RandomCrop((config.input_size, config.input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=32. / 255., saturation=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_dataset = CustomDataset(data_config['train'],
                                  data_config['train_root'],
                                  transform=transform_train)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.workers,
                              pin_memory=True,
                              worker_init_fn=_init_fn)

    transform_test = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_dataset = CustomDataset(data_config['val'],
                                data_config['val_root'],
                                transform=transform_test)
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True,
                            worker_init_fn=_init_fn)
    # logging dataset info
    print('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.
          format(dataset_name=config.dataset,
                 train_num=len(train_dataset),
                 val_num=len(val_dataset)))
    print('Batch Size:[{0}], Total:::Train Batches:[{1}],Val Batches:[{2}]'.
          format(config.batch_size, len(train_loader), len(val_loader)))
    # define model
    if config.model_name == 'inception':
        net = inception_v3_bap(pretrained=True,
                               aux_logits=False,
                               num_parts=config.parts)
    elif config.model_name == 'resnet50':
        net = resnet50(pretrained=True, use_bap=True)

    in_features = net.fc_new.in_features
    new_linear = torch.nn.Linear(in_features=in_features,
                                 out_features=train_dataset.num_classes)
    net.fc_new = new_linear
    # feature center
    feature_len = 768 if config.model_name == 'inception' else 512
    center_dict = {
        'center':
        torch.zeros(train_dataset.num_classes, feature_len * config.parts)
    }

    # gpu config
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
        center_dict['center'] = center_dict['center'].cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and config.multi_gpu:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)

    # define optimizer
    assert config.optim in ['sgd', 'adam'], 'optim name not found!'
    if config.optim == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=config.lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optim == 'adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=config.lr,
                                     weight_decay=config.weight_decay)

    # define learning scheduler
    assert config.scheduler in ['plateau',
                                'step'], 'scheduler not supported!!!'
    if config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               patience=3,
                                                               factor=0.1)
    elif config.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=2,
                                                    gamma=0.9)

    # define loss
    criterion = torch.nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()

    # train val parameters dict
    state = {
        'model': net,
        'train_loader': train_loader,
        'val_loader': val_loader,
        'criterion': criterion,
        'center': center_dict['center'],
        'config': config,
        'optimizer': optimizer
    }
    ## train and val
    engine = Engine()
    print(config)
    for e in range(config.epochs):
        if config.scheduler == 'step':
            scheduler.step()
        lr_val = get_lr(optimizer)
        print("Start epoch %d ==========,lr=%f" % (e, lr_val))
        train_prec, train_loss = engine.train(state, e)
        prec1, val_loss = engine.validate(state)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': e + 1,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'center': center_dict['center']
            }, is_best, config.checkpoint_path)
        sw.add_scalars("Accurancy", {'train': train_prec, 'val': prec1}, e)
        sw.add_scalars("Loss", {'train': train_loss, 'val': val_loss}, e)
        if config.scheduler == 'plateau':
            scheduler.step(val_loss)
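`save_checkpoint(state, is_best, path)` here follows the usual "always save latest, copy on best" convention. One plausible implementation, hedged since the real helper is not shown and the `_best` filename is an assumption:

import shutil
import torch

def save_checkpoint(state, is_best, path):
    # Always persist the most recent state; keep an extra copy when it
    # is the best seen so far.
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, path.replace('.pth', '_best.pth'))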
Exemple #27
0
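                        # (fragment: visualization and logging tail of a train/val loop;
                        #  vis, out, the meters and losses come from omitted code above)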
                        result_render = vis.render(
                            result[idx_frame], lim_low=0., lim_high=0.4, text='selfsupervised')

                        frame = vis.merge_frames(
                            [touch_raw_render, touch_render, result_render], nx=1, ny=3)

                        out.write(frame)

            running_sample += B
            running_loss += loss.item() * B
            running_loss_scale += loss_scale.item() * B
            running_loss_recon += loss_recon.item() * B

            if i % args.log_per_iter == 0:
                print('[%d/%d][%d/%d] LR: %.6f, loss: %.4f (%.4f), scale: %.4f (%.4f), recon: %.4f (%.4f)' % (
                    epoch, args.n_epoch, i, len(dataloaders[phase]), get_lr(optimizer),
                    loss.item(), running_loss / running_sample,
                    loss_scale.item(), running_loss_scale / running_sample,
                    loss_recon.item(), running_loss_recon / running_sample))

            if i > 0 and i % args.ckp_per_iter == 0 and phase == 'train':
                model_path = '%s/net_epoch_%d_iter_%d.pth' % (args.ckp_path, epoch, i)
                torch.save(model.state_dict(), model_path)


        loss_cur = running_loss / running_sample
        loss_cur_scale = running_loss_scale / running_sample
        loss_cur_recon = running_loss_recon / running_sample
        print('[%d/%d %s] loss: %.4f, scale: %.4f, recon: %.4f, best_valid_loss: %.4f' % (
            epoch, args.n_epoch, phase, loss_cur, loss_cur_scale, loss_cur_recon, best_valid))
        print('[%d/%d %s] min_value: %.4f, max_value: %.4f, gBias: %.4f, gScale: %.4f' % (
Exemple #28
0
def train(
    batch_size,
    n_epochs,
    rnn_type,
    bidir,
    n_layers,
    hidden_dim,
    embedding_dim,
    teacher_forcing_ratio,
    src_vocab_size,
    tgt_vocab_size,
    learning_rate,
    dropout_p,
    train_dataloader,
    val_dataloader,
    metric=None,
    device=DEVICE,
    savename="model",
    pretrained_emb=None,
    is_resume=False,
    **kwargs,
):

    model = nn.DataParallel(
        NMT(
            src_vocab_size,
            tgt_vocab_size,
            embedding_dim,
            hidden_dim,
            n_layers,
            bidir,
            dropout_p,
            rnn_type.lower(),
            pretrained_emb,
        ).to(device))
    print("num of vocab:", src_vocab_size)

    start_epoch = 0
    model.module.init_weight()

    criterion = nn.NLLLoss(ignore_index=train_dataloader.pad_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="min",
                                                           factor=0.1,
                                                           patience=5)
    if is_resume:
        checkpoint = torch.load(os.path.join("save", "model_checkpoint.pth"))
        model.module.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        save_checkpoint.best = checkpoint["bleu"]

        start_epoch = checkpoint["epoch"] + 1
        for p in model.parameters():
            p.requires_grad = True

    for epoch in range(start_epoch, n_epochs):
        running_loss = 0.0
        running_total = 0
        n_predict_words = 0
        model.train()
        for i, (train_X, len_X, train_y, len_y) in enumerate(train_dataloader):
            train_X = train_X.to(DEVICE)
            train_y = train_y.to(DEVICE)
            hidden = [
                h.to(device)
                for h in model.module.init_hidden(train_X.shape[0])
            ]

            log_p = model(train_X, train_y, len_X, hidden,
                          teacher_forcing_ratio)
            # remove <sos>
            loss = criterion(log_p, train_y[:, 1:].reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           5.0)  # gradient clipping
            optimizer.step()

            # statistics
            running_total += train_X.shape[0]
            running_loss += loss.item()
            n_predict_words += len_y.sum().item() - len_y.shape[0]
            _loss = running_loss / (i + 1)

            progress_bar(
                running_total,
                train_dataloader.n_examples,
                epoch + 1,
                n_epochs,
                {
                    "loss": _loss,
                    "lr": get_lr(optimizer),
                    "ppl": evaluate_ppl(running_loss, n_predict_words),
                },
            )

        WRITER.add_scalar("train_loss", _loss, epoch + 1)
        _score = validate(val_dataloader, model, criterion, epoch, device)
        scheduler.step(_loss)
        save_checkpoint(
            {
                "net": model.module.state_dict(),
                "epoch": epoch,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "itos": train_dataloader.itos,
            },
            "bleu",
            _score,
            epoch,
            savename,
        )
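`evaluate_ppl` converts the running NLL into perplexity, the exponentiated average loss per predicted word. A sketch consistent with how the counters are accumulated above; the exact normalization in the real helper may differ:

import math

def evaluate_ppl(running_loss, n_predict_words):
    # Perplexity = exp(average negative log-likelihood per predicted word).
    return math.exp(running_loss / max(n_predict_words, 1))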
Exemple #29
0
def main_worker(index, opt):
    random.seed(opt.manual_seed)
    np.random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)

    if index >= 0 and opt.device.type == 'cuda':
        opt.device = torch.device(f'cuda:{index}')

    if opt.distributed:
        opt.dist_rank = opt.dist_rank * opt.ngpus_per_node + index
        dist.init_process_group(backend='nccl',
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.dist_rank)
        opt.batch_size = int(opt.batch_size / opt.ngpus_per_node)
        opt.n_threads = int(
            (opt.n_threads + opt.ngpus_per_node - 1) / opt.ngpus_per_node)
    opt.is_master_node = not opt.distributed or opt.dist_rank == 0

    model = generate_model(opt)
    if opt.batchnorm_sync:
        assert opt.distributed, 'SyncBatchNorm only supports DistributedDataParallel.'
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opt.pretrain_path:
        model = load_pretrained_model(model, opt.pretrain_path, opt.model,
                                      opt.n_finetune_classes, opt.device)
    if opt.resume_path is not None:
        model = resume_model(opt.resume_path, opt.arch, model, opt.device)
    model = make_data_parallel(model, opt.distributed, opt.device)

    if opt.pretrain_path:
        parameters = get_fine_tuning_parameters(model, opt.ft_begin_module)
    else:
        parameters = model.parameters()

    if opt.is_master_node:
        print(model)

    criterion = CrossEntropyLoss().to(opt.device)

    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)

    if opt.tensorboard and opt.is_master_node:
        # from torch.utils.tensorboard import SummaryWriter
        from tensorboardX import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None

    prev_val_loss = None
    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, tb_writer, opt.distributed)

            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)

        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, tb_writer,
                                      opt.distributed)

        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)

    if opt.inference:
        inference_loader, inference_class_names = get_inference_utils(opt)
        inference_result_path = opt.result_path / '{}.json'.format(
            opt.inference_subset)

        inference_results = inference.inference(inference_loader, model,
                                                inference_result_path,
                                                inference_class_names,
                                                opt.inference_no_average,
                                                opt.output_topk, opt.device)
        return inference_results
    return {}
Exemple #30
0
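# (fragment: tail of a test(epoch, shift) function; the accuracy print above
#  this point is truncated)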
             (batch_idx + 1), 100. * correct / total, correct, total))
        acc_per_class.output_class_prediction()

    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'origin_acc': acc if not shift else best_acc_original,
            'shifted_acc': acc if shift else best_acc_shifted,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_noreg.pth')
        if shift:
            best_acc_shifted = acc
        else:
            best_acc_original = acc


for epoch in range(start_epoch, start_epoch + 200):
    optimizer = optim.SGD(net.parameters(),
                          lr=get_lr(epoch, args.lr),
                          momentum=0.5,
                          weight_decay=5e-4)
    train(epoch)
    test(epoch, shift=False)
    test(epoch, shift=True)
    # TODO: pass the confident samples to train_loader
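Unlike the other examples, this snippet's `get_lr(epoch, args.lr)` is a manual schedule over epochs rather than a reader of `optimizer.param_groups`; note that rebuilding the optimizer each epoch also resets SGD momentum buffers. A plausible step-decay sketch of such a schedule (the decay points are an assumption):

def get_lr(epoch, base_lr):
    # Hypothetical step decay: divide the rate by 10 every 80 epochs.
    return base_lr * (0.1 ** (epoch // 80))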