Ejemplo n.º 1
0
def do_validate(model, optimizer, criterion, metrics, scheduler, options,
                timeit):
    """Run validation, publish the results and write a checkpoint.

    ``metrics_values`` returned by ``validate`` is keyed by metric name. The
    first entry of ``metrics`` is the primary metric: its value decides (via
    ``update_best_runtime_metric``) whether this checkpoint is flagged as the
    best so far. With no metric values the checkpoint is saved with
    ``is_best=False`` and only the loss is logged.
    """
    metrics_values, loss = validate(model, optimizer, criterion, metrics,
                                    options)

    # Remember the cumulative timer value at which this validation ran.
    options.runtime['cumu_time_val'].append(timeit.cumu)

    if len(metrics_values) == 0:
        is_best = False
        log.debug("Validation loss={:.3f}".format(loss), 0)
    else:
        primary = metrics[0]
        is_best, best_metric_name = update_best_runtime_metric(
            options, metrics_values[primary.name], primary.name)
        log.log_val(options, best_metric_name)
        for metric_name, metric_value in metrics_values.items():
            log.post_metrics(options, metric_name, metric_value)

    log.post_metrics(options, 'val_loss', loss)
    checkpoint.save(options, model, optimizer, scheduler, is_best)
Ejemplo n.º 2
0
def do_validate(model, optimizer, criterion, metrics, scheduler, options,
                timeit):
    """Validate the model, checkpoint it, and append to the validation history.

    ``metrics_values`` is indexed positionally, parallel to ``metrics``;
    ``metrics[0]`` is the primary metric used to decide whether the checkpoint
    is the best so far. A checkpoint is only written when metric values exist.
    The timer is paused after ``validate`` returns and resumed once all the
    bookkeeping below is done.
    """
    metrics_values, loss = validate(model, optimizer, criterion, metrics,
                                    options)
    timeit.pause()

    if len(metrics_values) > 0:
        primary_name = metrics[0].name
        is_best, best_metric_name = update_best_runtime_metric(
            options, metrics_values[0], primary_name)
        checkpoint.save(options, model, optimizer, scheduler, is_best)
        log.log_val(options, best_metric_name)
        for metric, metric_value in zip(metrics, metrics_values):
            log.post_metrics(options, metric.name, metric_value)

    log.post_metrics(options, 'Validation Loss', loss)

    history = options.runtime
    history['val_loss_hist'].append(loss)
    history['val_metrics_hist'].append(metrics_values)
    history['val_time'].append(timeit.cumu)
    timeit.resume()
Ejemplo n.º 3
0
def do_validate(model, optimizer, criterion, metrics, scheduler, options, timeit):
    """Evaluate the model, log its metrics and loss, then save a checkpoint.

    Metric values come back keyed by name; ``metrics[0]`` is the primary metric
    that decides whether the checkpoint is marked as best. When no metric values
    are returned the checkpoint is never marked best.
    """
    values_by_name, val_loss = validate(model, optimizer, criterion, metrics, options)
    options.runtime['cumu_time_val'].append(timeit.cumu)

    has_metrics = len(values_by_name) > 0
    if has_metrics:
        primary_metric = metrics[0]
        primary_value = values_by_name[primary_metric.name]
        is_best, best_metric_name = update_best_runtime_metric(options, primary_value, primary_metric.name)
        log.log_val(options, best_metric_name)
        for metric_name, metric_value in values_by_name.items():
            log.post_metrics(options, metric_name, metric_value)
    else:
        is_best = False
        log.debug("Validation loss={:.3f}".format(val_loss), 0)

    log.post_metrics(options, 'val_loss', val_loss)
    checkpoint.save(options, model, optimizer, scheduler, is_best)
Ejemplo n.º 4
0
def run(net, logger, hps):
    """Train ``net`` up to epoch 100 at lr=1e-3, then epochs 100-119 at lr=1e-4.

    Training starts at ``hps['start_epoch']``. Every 20 epochs the network and
    the logger plots are saved. Afterwards the test accuracy and loss are
    printed.

    Args:
        net: model to train; moved to the module-level ``device``.
        logger: object with ``loss_train``/``acc_train``/``loss_val``/
            ``acc_val`` lists and a ``save_plt(hps)`` method.
        hps: hyper-parameter dict; reads ``'name'`` and ``'start_epoch'``.
    """
    print("Running simple training loop")
    # Create dataloaders
    trainloader, valloader, testloader = get_dataloaders()

    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    def run_epoch(epoch):
        # One train + validation pass with logging and periodic saving.
        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer)
        logger.loss_train.append(loss_tr)
        logger.acc_train.append(acc_tr)

        acc_v, loss_v = evaluate(net, valloader, criterion)
        logger.loss_val.append(loss_v)
        logger.acc_val.append(acc_v)

        if (epoch + 1) % 20 == 0:
            save(net, logger, hps, epoch + 1)
            logger.save_plt(hps)

        print('Epoch %2d' % (epoch + 1),
              'Train Accuracy: %2.2f %%' % acc_tr,
              'Val Accuracy: %2.2f %%' % acc_v,
              sep='\t\t')

    start_epoch = hps['start_epoch']

    print("Training", hps['name'], "on", device)
    for epoch in range(start_epoch, 100):
        run_epoch(epoch)

    # Reduce Learning Rate
    print("Reducing learning rate from 0.001 to 0.0001")
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.0001

    # Train for 20 extra epochs (100..119). Start at 100, or later when
    # resuming past it, so the last first-phase epoch is not repeated and no
    # loop variable from the first phase is needed.
    for epoch in range(max(start_epoch, 100), 120):
        run_epoch(epoch)

    acc_test, loss_test = evaluate(net, testloader, criterion)
    # Accuracy is a percentage; loss is not, so it gets no '%' suffix.
    print('Test Accuracy: %2.2f %%' % acc_test,
          'Test Loss: %2.6f' % loss_test,
          sep='\t\t')
def run(net):
    """Train ``net`` for ``hps['n_epochs']`` epochs, saving it on each new best.

    Relies on the module-level ``hps``, ``logger`` and ``device``. Uses Nesterov
    SGD and ``ReduceLROnPlateau`` driven by validation accuracy; the logs are
    saved every 5 epochs.
    """
    trainloader, valloader = prepare_data()
    net = net.to(device)

    optimizer = torch.optim.SGD(
        net.parameters(), lr=hps['lr'], momentum=0.9, nesterov=True,
        weight_decay=0.0001)
    scheduler = ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=10, verbose=True)
    criterion = nn.CrossEntropyLoss()

    best_acc_v = 0

    print("Training", hps['name'], "on", device)
    for epoch in range(hps['n_epochs']):
        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer)
        logger.loss_train.append(loss_tr)
        logger.acc_train.append(acc_tr)

        acc_v, loss_v = evaluate(net, valloader, criterion)
        logger.loss_val.append(loss_v)
        logger.acc_val.append(acc_v)

        # Plateau detection on validation accuracy (mode='max').
        scheduler.step(acc_v)

        if (epoch + 1) % 5 == 0:
            logger.save(hps)

        columns = ['Epoch %2d' % (epoch + 1),
                   'Train Accuracy: %2.2f %%' % acc_tr,
                   'Val Accuracy: %2.2f %%' % acc_v]
        if acc_v > best_acc_v:
            save(net, hps)
            best_acc_v = acc_v
            columns.append('Network Saved')
        print(*columns, sep='\t\t')
def run(net):
    """Train one cross-validation fold of ``net`` and keep its best weights.

    The fold is selected by the module-level ``hps['fold_id']``. Uses Adam with
    a ``GradScaler`` passed to ``train`` and a cosine-annealing-with-warm-restarts
    schedule that advances once per epoch. The network is saved (tagged with
    the fold id) whenever validation accuracy improves.
    """
    folds = prepare_folds()
    trainloader, valloader = folds[int(hps['fold_id'])]
    net = net.to(device)

    scaler = GradScaler()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
    criterion = nn.CrossEntropyLoss().to(device)
    best_acc_v = 0

    print("Training", hps['name'], hps['fold_id'], "on", device)
    for epoch in range(hps['n_epochs']):
        epoch_no = epoch + 1

        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer, scaler,
                                epoch_no)
        acc_v, loss_v = evaluate(net, valloader, criterion, epoch_no)

        # Cosine schedule: one step per epoch (not driven by a plateau).
        scheduler.step()

        improved = acc_v > best_acc_v
        if improved:
            save(net, hps, desc=hps['fold_id'])
            best_acc_v = acc_v

        report = ['Epoch %2d' % epoch_no,
                  'Train Accuracy: %2.2f %%' % acc_tr,
                  'Train Loss: %2.6f' % loss_tr,
                  'Val Accuracy: %2.2f %%' % acc_v,
                  'Val Loss: %2.6f' % loss_v]
        if improved:
            report.append('Network Saved')
        print(*report, sep='\t\t')
Ejemplo n.º 7
0
def run(net, logger, hps):
    """Fit ``net`` with the pseudo-inverse optimizer, save it, then evaluate.

    Only the train and validation loaders are used; the third loader returned
    by ``get_dataloaders`` is ignored. The model is saved as epoch 1.
    """
    print('start loading data')
    trainloader, valloader, _unused_testloader = get_dataloaders(bs=16)
    net = net.to(device)

    optimizer = pseudoInverse(net.parameters(), C=0.001, L=0)

    print("Training", hps['name'], "on", device)
    train(net, optimizer, trainloader)

    print("Saving Model")
    save(net, logger, hps, epoch=1)

    print("Evaluating", hps['name'])
    test(net, valloader)
Ejemplo n.º 8
0
def run(net, logger, hps):
    """Train ``net`` for 300 epochs with Nesterov SGD and report test results.

    The initial learning rate is ``float(hps['lr'])``; ``ReduceLROnPlateau``
    halves it when validation accuracy has not improved for 10 epochs. Every
    ``hps['save_freq']`` epochs the network and logger plots are saved.

    Args:
        net: model to train; moved to the module-level ``device``.
        logger: object with ``loss_train``/``acc_train``/``loss_val``/
            ``acc_val`` lists and a ``save_plt(hps)`` method.
        hps: hyper-parameter dict; reads ``'lr'``, ``'name'``, ``'save_freq'``.
    """
    # Create dataloaders
    print('start loading data')

    trainloader, valloader, testloader = get_dataloaders()
    net = net.to(device)
    learning_rate = float(hps['lr'])
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=0.0001)
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='max',
                                  factor=0.5,
                                  patience=10,
                                  verbose=True)
    criterion = nn.CrossEntropyLoss()
    print("Training", hps['name'], "on", device)

    for epoch in range(300):
        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer)
        logger.loss_train.append(loss_tr)
        logger.acc_train.append(acc_tr)

        acc_v, loss_v = evaluate(net, valloader, criterion)
        logger.loss_val.append(loss_v)
        logger.acc_val.append(acc_v)

        # Update learning rate if validation accuracy plateaus (mode='max').
        scheduler.step(acc_v)

        if (epoch + 1) % hps['save_freq'] == 0:
            save(net, logger, hps, epoch + 1)
            logger.save_plt(hps)

        print('Epoch %2d' % (epoch + 1),
              'Train Accuracy: %2.2f %%' % acc_tr,
              'Val Accuracy: %2.2f %%' % acc_v,
              sep='\t\t')

    acc_test, loss_test = evaluate(net, testloader, criterion)
    # Accuracy is a percentage; loss is not, so it gets no '%' suffix.
    print('Test Accuracy: %2.2f %%' % acc_test,
          'Test Loss: %2.6f' % loss_test,
          sep='\t\t')
Ejemplo n.º 9
0
def main(args):
    """Build a sequence model from ``args`` and train it, or only evaluate it.

    Works on a deep copy of ``args`` and stores derived values on that copy
    (``device``, ``data_omit_label_idx``, ``ep``). The model class is chosen by
    the ``feedback`` / ``expire_span`` / ``compress`` flags, falling back to
    ``TransformerSeq``. Training resumes from the epoch returned by
    ``checkpoint.load``.

    With ``args.full_test`` the model is evaluated once on the validation and
    test splits and the results are printed (error rate or bits-per-character
    when ``data_type == "char"``, perplexity otherwise); no training happens.

    Otherwise every epoch trains on ``train_data``, evaluates on ``val_data``,
    combines stats through ``distributed.collect_stat`` and, on rank 0, logs
    and checkpoints.

    Raises:
        ValueError: if ``args.optim`` is neither ``"sgd"`` nor ``"adam"``.
    """
    args = copy.deepcopy(args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    update_args(args)

    distributed.init(args)
    args.device = torch.device("cuda" if use_cuda else "cpu")
    logger = Logger(args)
    logger.print(f"PyTorch version: {torch.__version__}")
    logger.print(f"PyTorch CUDA version: {torch.version.cuda}")
    logger.print(str(args))

    # load data
    train_data, val_data, test_data, corpus = data.get_data(
        args, logger, args.data_eos)
    if len(args.data_omit_labels) > 0:
        args.data_omit_label_idx = [
            corpus.dictionary.word2idx[w] for w in args.data_omit_labels
        ]
    else:
        args.data_omit_label_idx = None

    # create a model
    if args.feedback:
        model = feedback.FeedbackTransformer(args)
    elif args.expire_span:
        model = expire_span.ExpireSpan(args)
    elif args.compress:
        model = compressive.CompressiveTransformer(args)
    else:
        model = transformer_seq.TransformerSeq(args)
    model.to(args.device)

    # Only parameters that require gradients are counted and optimized.
    params = [param for param in model.parameters() if param.requires_grad]
    nparameters = sum(param.numel() for param in params)
    logger.print("nparameters={:.2f}M".format(nparameters / 1e6))

    # OPTIM param
    if args.optim == "sgd":
        optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.optim == "adam":
        optimizer = optim.Adam(params, lr=args.lr)
    else:
        # Fail fast instead of leaving `optimizer` unbound.
        raise ValueError(f"unsupported optimizer: {args.optim!r}")

    if args.lr_decay:
        # will do warm-up manually later
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.nepochs * args.nbatches)
    elif args.lr_warmup > 0:
        # Linear warm-up: scale lr by min(1, step / lr_warmup).
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: min(1, ep / args.lr_warmup))
    else:
        scheduler = None

    model = distributed.wrap_model(args, model)

    ep_init = checkpoint.load(args, model, optimizer, logger, scheduler)

    # Read positions per split (train, val, test). Val/test start at 0; the
    # training position starts at a random offset.
    pos = [0 for _ in range(3)]
    train_tensor = train_data[0] if isinstance(train_data, tuple) else train_data
    pos[0] = random.randrange(train_tensor.size(1) - args.mem_sz)
    # NOTE(review): `model.module` assumes wrap_model returns a wrapper
    # exposing the underlying model as `.module` — confirm.
    hid_cache = [
        model.module.init_hid_cache(args.batch_sz),
        model.module.init_hid_cache(args.test_batch_sz),
        model.module.init_hid_cache(args.test_batch_sz),
    ]

    def evaluate_split(split_data, idx):
        # Run `train` in test-only mode on one split, starting from that
        # split's stored position and hidden-state cache.
        return train(
            args,
            model,
            optimizer,
            scheduler,
            split_data,
            test_only=True,
            train_pos=pos[idx],
            h_cache=hid_cache[idx],
            corpus=corpus,
        )

    if args.full_test:
        # perform evaluation only
        with torch.no_grad():
            stat_val, pos[1], hid_cache[1] = evaluate_split(val_data, 1)
            stat_test, pos[2], hid_cache[2] = evaluate_split(test_data, 2)
            gpu_mem = torch.cuda.max_memory_allocated() / 1024**3
            stat_test, stat_val, gpu_mem = distributed.collect_stat(
                args, stat_test, stat_val, gpu_mem)
            if args.data_type == "char":
                if "err" in stat_val:
                    logger.print("val err: {:.3f}%".format(stat_val["err"] *
                                                           100))
                    logger.print("test err: {:.3f}%".format(stat_test["err"] *
                                                            100))
                else:
                    logger.print("val: {:.3f}bpc".format(stat_val["loss"] /
                                                         math.log(2)))
                    logger.print("test: {:.3f}bpc".format(stat_test["loss"] /
                                                          math.log(2)))
            else:
                logger.print("val: {:.3f}ppl".format(math.exp(
                    stat_val["loss"])))
                logger.print("test: {:.3f}ppl".format(
                    math.exp(stat_test["loss"])))
            logger.print(f"gpu_mem: {gpu_mem:.1f}gb")
        return

    for ep in range(ep_init, args.nepochs):
        t_sta = time.time()
        args.ep = ep
        stat_train, pos[0], hid_cache[0] = train(
            args,
            model,
            optimizer,
            scheduler,
            train_data,
            train_pos=pos[0],
            h_cache=hid_cache[0],
            corpus=corpus,
        )
        # Average wall-clock milliseconds per training batch.
        elapsed = 1000 * (time.time() - t_sta) / args.nbatches
        with torch.no_grad():
            stat_val, val_pos, val_cache = evaluate_split(val_data, 1)
        if not args.full_valid:
            # Partial validation continues from where it stopped; with
            # full_valid the stored position is left unchanged every epoch.
            pos[1], hid_cache[1] = val_pos, val_cache

        gpu_mem = torch.cuda.max_memory_allocated() / 1024**3
        # Replaces the deprecated reset_max_memory_allocated(), which just
        # forwards to this call.
        torch.cuda.reset_peak_memory_stats()
        stat_train, stat_val, gpu_mem = distributed.collect_stat(
            args, stat_train, stat_val, gpu_mem)

        if args.rank == 0:
            # only the master process will do logging, plotting and checkpoint
            if args.lr_decay:
                logger.log("compute/lr", optimizer.param_groups[0]["lr"])
            if args.adapt_span:
                adaptive_span.log(args, model, logger, stat_train)
            if args.expire_span:
                expire_span.log(args, model, logger, stat_train)
            if args.feedback:
                feedback.log(args, model, logger, stat_train)

            logger.step(args, stat_train, stat_val, elapsed, gpu_mem)
            checkpoint.save(args, model, optimizer, logger, scheduler)
def main(epoch_num, batch_size, lr, num_gpu, img_size, data_path, log_path,
         resume, eval_intvl, cp_intvl, vis_intvl, num_workers):
    """Train a ResUNet on KiTS19 with periodic evaluation and checkpointing.

    Args:
        epoch_num: total number of epochs; when resuming, training continues
            at the epoch after the one stored in the checkpoint.
        batch_size: forwarded to ``training``/``evaluation``.
        lr: initial Adam learning rate.
        num_gpu: number of GPUs for ``DataParallel`` (device ids 0..num_gpu-1).
        img_size: output size for ``MedicalTransform`` and the dataset.
        data_path: KiTS19 data directory.
        log_path: TensorBoard log directory; checkpoints are written to
            ``log_path/checkpoint``.
        resume: checkpoint file to resume from, or a falsy value.
        eval_intvl: evaluate every this many epochs (disabled when <= 0).
        cp_intvl: write ``cp_XXX.pth`` every this many epochs (disabled when <= 0).
        vis_intvl: forwarded to ``training``/``evaluation``.
        num_workers: forwarded to ``training``/``evaluation``.

    Exits with status -1 when ``log_path`` already contains files and
    ``resume`` is not set. On Ctrl-C the current state is saved to
    ``INTERRUPTED.pth`` and the function returns.
    """
    data_path = Path(data_path)
    log_path = Path(log_path)
    cp_path = log_path / 'checkpoint'

    # Refuse to start a fresh run on top of an existing run's logs.
    if not resume and log_path.exists() and len(list(log_path.glob('*'))) > 0:
        print(f'log path "{str(log_path)}" has old file', file=sys.stderr)
        sys.exit(-1)
    cp_path.mkdir(parents=True, exist_ok=True)

    transform = MedicalTransform(output_size=img_size,
                                 roi_error_range=15,
                                 use_roi=False)

    dataset = KiTS19(data_path,
                     stack_num=5,
                     spec_classes=[0, 1, 1],
                     img_size=img_size,
                     use_roi=False,
                     train_transform=transform,
                     valid_transform=transform)

    net = ResUNet(in_ch=dataset.img_channels,
                  out_ch=dataset.num_classes,
                  base_ch=64)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    start_epoch = 0
    if resume:
        # cp.load_params is expected to update this dict in place (at least
        # 'epoch'), restoring net/optimizer state on the CPU.
        data = {'net': net, 'optimizer': optimizer, 'epoch': 0}
        cp.load_params(data, Path(resume), device='cpu')
        start_epoch = data['epoch'] + 1

    criterion = nn.CrossEntropyLoss()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5,
        verbose=True,
        threshold=0.0001,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-08)

    logger = SummaryWriter(str(log_path))

    gpu_ids = list(range(num_gpu))

    print(f'{" Start training ":-^40s}\n')
    msg = f'Net: {net.__class__.__name__}\n' + \
          f'Dataset: {dataset.__class__.__name__}\n' + \
          f'Epochs: {epoch_num}\n' + \
          f'Learning rate: {optimizer.param_groups[0]["lr"]}\n' + \
          f'Batch size: {batch_size}\n' + \
          f'Device: cuda{str(gpu_ids)}\n'
    print(msg)

    torch.cuda.empty_cache()

    # to GPU device
    net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()
    criterion = criterion.cuda()
    # Optimizer state loaded on the CPU must follow the parameters to the GPU.
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.cuda()

    # start training
    valid_score = 0.0
    best_score = 0.0
    best_epoch = 0

    try:
        for epoch in range(start_epoch, epoch_num):
            epoch_str = f' Epoch {epoch + 1}/{epoch_num} '
            print(f'{epoch_str:-^40s}')
            print(f'Learning rate: {optimizer.param_groups[0]["lr"]}')

            net.train()
            torch.set_grad_enabled(True)
            transform.train()
            try:
                training(net, dataset, criterion, optimizer, scheduler,
                         epoch, batch_size, num_workers, vis_intvl, logger)

                if eval_intvl > 0 and (epoch + 1) % eval_intvl == 0:
                    net.eval()
                    torch.set_grad_enabled(False)
                    transform.eval()

                    train_score = evaluation(net,
                                             dataset,
                                             epoch,
                                             batch_size,
                                             num_workers,
                                             vis_intvl,
                                             logger,
                                             type='train')
                    valid_score = evaluation(net,
                                             dataset,
                                             epoch,
                                             batch_size,
                                             num_workers,
                                             vis_intvl,
                                             logger,
                                             type='valid')

                    print(f'Train data score: {train_score:.5f}')
                    print(f'Valid data score: {valid_score:.5f}')

                # valid_score only changes on evaluation epochs, so a new best
                # can only be recorded right after an evaluation.
                if valid_score > best_score:
                    best_score = valid_score
                    best_epoch = epoch
                    cp_file = cp_path / 'best.pth'
                    cp.save(epoch, net.module, optimizer, str(cp_file))
                    print('Update best acc!')
                    logger.add_scalar('best/epoch', best_epoch + 1, 0)
                    logger.add_scalar('best/score', best_score, 0)

                # Guarded like eval_intvl: cp_intvl <= 0 disables periodic
                # checkpoints instead of raising ZeroDivisionError.
                if cp_intvl > 0 and (epoch + 1) % cp_intvl == 0:
                    cp_file = cp_path / f'cp_{epoch + 1:03d}.pth'
                    cp.save(epoch, net.module, optimizer, str(cp_file))

                print(f'Best epoch: {best_epoch + 1}')
                print(f'Best score: {best_score:.5f}')

            except KeyboardInterrupt:
                cp_file = cp_path / 'INTERRUPTED.pth'
                cp.save(epoch, net.module, optimizer, str(cp_file))
                return
    finally:
        # Flush pending TensorBoard events and release the writer.
        logger.close()
Ejemplo n.º 11
0
def train(args):
    """Train a VGG16 + fully-connected score-distribution model (TensorFlow 1.x).

    Builds a graph that reads training batches and single test images through
    ``ImageReader``, predicts a 10-bin softmax distribution, and minimizes the
    EMD loss with Adam. The learning rate starts at ``args.learning_rate`` and
    is divided by 5 at 50% and again at 80% of ``args.iter_max`` steps.
    Training stats and summaries are written about once per epoch
    (``len(train images) // batch_size`` steps). Every ``args.eval_step``
    steps, when ``args.save_ckpt_file`` is set, a checkpoint is saved and, if
    ``args.is_eval``, the test set is scored with SROCC/KROCC/PLCC/RMSE/MSE.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.train.create_global_step()
        with tf.name_scope("create_train_inputs"):
            reader = ImageReader(
                args.data_dir,
                args.train_list,
                (args.crop_height, args.crop_width),
                args.is_training,
            )
            image_batch, label_batch, mean_std_batch = reader.dequeue(
                args.batch_size)

        with tf.name_scope("create_test_inputs"):
            test_reader = ImageReader(
                args.data_dir,
                args.test_list,
                (args.crop_height, args.crop_width),
                False,
            )
            # Add a leading batch dimension of 1 to each single test example.
            test_image = tf.expand_dims(test_reader.image, axis=0)
            test_score = tf.expand_dims(test_reader.score, axis=0)
            test_mean_std = tf.expand_dims(test_reader.mean_std, axis=0)

        # Placeholders fed with numpy batches pulled from the readers.
        imgs = tf.placeholder(tf.float32,
                              [None, args.crop_height, args.crop_width, 3])
        scores = tf.placeholder(tf.float32, [None, 10])

        with tf.name_scope("create_models"):
            vgg = vgg16(imgs)
            x = fully_connection(vgg.pool5, 128, 0.5)
            scores_hat = tf.nn.softmax(x)

        means = tf.placeholder(tf.float32, [None, 1])
        with tf.name_scope("create_loss"):
            emd_loss_out = _emd(scores, scores_hat)
            mean_hat = scores_stats(scores_hat)
            l2_loss = reg_l2(means, mean_hat)
            # The L2 term is reported but weighted by 0.0, so only the EMD
            # loss contributes gradients.
            loss = emd_loss_out + l2_loss * 0.0

        # The learning rate is fed each step so the Python-side schedule applies.
        lr = tf.placeholder(tf.float32, [])
        with tf.name_scope("create_optimize"):
            optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)

        tf.summary.scalar('learning_rate', lr)
        tf.summary.scalar('emd_loss', emd_loss_out)
        # Build the summary Tensor based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

    with tf.Session(graph=graph) as sess:

        sess.run(tf.global_variables_initializer())
        vgg.load_weights(args.pretrain_weights, sess)

        # create queue coordinator and start the input-queue threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)

        train_log_dir = os.path.join(
            args.logs_dir, 'train/{}-{}'.format(args.exp_name, timestamp))
        test_log_dir = os.path.join(
            args.logs_dir, 'test/{}-{}'.format(args.exp_name, timestamp))
        summary_writer = tf.summary.FileWriter(train_log_dir,
                                               sess.graph,
                                               filename_suffix=args.exp_name)
        summary_test = tf.summary.FileWriter(test_log_dir,
                                             filename_suffix=args.exp_name)

        import time
        start_time = time.time()
        base_lr = args.learning_rate
        # Integer step bookkeeping: a float modulo such as
        # `(step + 1) % (0.5 * iter_max)` never matches when the divisor is
        # not a whole number, silently skipping LR decays / epoch logging.
        first_decay_step = args.iter_max // 2
        second_decay_step = args.iter_max * 4 // 5
        iters_per_epoch = max(1, len(reader.image_list) // args.batch_size)

        try:
            for step in range(args.iter_max):

                if step + 1 == first_decay_step:
                    base_lr = base_lr / 5
                if step + 1 == second_decay_step:
                    base_lr = base_lr / 5

                image_batch_, label_batch_, mean_std_batch_ = sess.run(
                    [image_batch, label_batch, mean_std_batch])
                # Keep column 0 of mean_std as an (N, 1) target for `means`.
                mean_std_batch_ = mean_std_batch_[:, 0].reshape(-1, 1)
                train_feed = {
                    imgs: image_batch_,
                    scores: label_batch_,
                    means: mean_std_batch_,
                    lr: base_lr
                }
                emd_loss_, l2_loss_, total_loss_, _ = sess.run(
                    [emd_loss_out, l2_loss, loss, optimizer],
                    feed_dict=train_feed)

                if step % iters_per_epoch == 0:
                    logger.info(
                        "step %d/%d, the emd loss is %f,l2_loss is %f,total loss is %f, time %f,learning rate: %lf"
                        % (step, args.iter_max, emd_loss_, l2_loss_,
                           total_loss_, (time.time() - start_time), base_lr))
                    summary_str = sess.run(summary_op, feed_dict=train_feed)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                if args.save_ckpt_file and step % args.eval_step == 0:
                    save(saver, sess, args.ckpt_dir, step)

                    if args.is_eval:
                        score_set = []
                        label_set = []
                        loss_set = []
                        for i in range(len(test_reader.image_list)):
                            image_, score_, mean_std_ = sess.run(
                                [test_image, test_score, test_mean_std])
                            label_set.append(mean_std_[:, 0])
                            emd_loss_out_, scores_hat_test = sess.run(
                                [emd_loss_out, scores_hat],
                                feed_dict={
                                    imgs: image_,
                                    scores: score_
                                })
                            loss_set.append(emd_loss_out_)
                            score_set.append(mean_score(scores_hat_test))
                            # One test summary per evaluation, taken from
                            # the 11th test image.
                            if i == 10:
                                summary_str = sess.run(summary_op,
                                                       feed_dict={
                                                           imgs: image_,
                                                           scores: score_,
                                                           lr: base_lr
                                                       })
                                summary_test.add_summary(summary_str, step)
                                summary_test.flush()

                        srocc, krocc, plcc, rmse, mse = evaluate_metric(
                            label_set, score_set)
                        print(len(label_set), len(score_set))
                        logger.info(
                            "==============evaluating test datasets :SROCC_v: %.3f\t KROCC: %.3f\t PLCC_v: %.3f\t RMSE_v: %.3f\t mse: %.3f, emd loss is: %f\n"
                            % (srocc, krocc, plcc, rmse, mse,
                               np.mean(loss_set)))

            logger.info("Optimization finish!")
        finally:
            # Always stop and join the queue threads, even on error.
            coord.request_stop()
            coord.join(threads)
Ejemplo n.º 12
0
def run(net, logger, hps):
    """Train ``net`` with Nesterov SGD, checkpointing best and periodic epochs.

    Trains from ``hps['start_epoch']`` to ``hps['n_epochs']`` using a
    ``GradScaler`` passed to ``train``. ``ReduceLROnPlateau`` multiplies the
    learning rate by 0.75 when validation accuracy stalls for 5 epochs. The
    network and logger plots are saved on every new best validation accuracy
    and every ``hps['save_freq']`` epochs. Finally the test-set accuracy and
    loss are printed.

    Args:
        net: model to train; moved to the module-level ``device``.
        logger: object with ``loss_train``/``acc_train``/``loss_val``/
            ``acc_val`` lists and a ``save_plt(hps)`` method.
        hps: hyper-parameter dict; reads ``'bs'``, ``'lr'``, ``'name'``,
            ``'start_epoch'``, ``'n_epochs'``, ``'save_freq'``.
    """
    # Create dataloaders
    trainloader, valloader, testloader = get_dataloaders(bs=hps['bs'])

    net = net.to(device)

    learning_rate = float(hps['lr'])
    scaler = GradScaler()

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=0.0001)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='max',
                                  factor=0.75,
                                  patience=5,
                                  verbose=True)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0

    print("Training", hps['name'], "on", device)
    for epoch in range(hps['start_epoch'], hps['n_epochs']):

        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer, scaler)
        logger.loss_train.append(loss_tr)
        logger.acc_train.append(acc_tr)

        acc_v, loss_v = evaluate(net, valloader, criterion)
        logger.loss_val.append(loss_v)
        logger.acc_val.append(acc_v)

        # Update learning rate on validation-accuracy plateaus (mode='max').
        scheduler.step(acc_v)

        is_new_best = acc_v > best_acc
        if is_new_best:
            best_acc = acc_v

        # A single save covers both triggers; previously an epoch that was
        # both a new best and a periodic epoch was written twice.
        if is_new_best or (epoch + 1) % hps['save_freq'] == 0:
            save(net, logger, hps, epoch + 1)
            logger.save_plt(hps)

        print('Epoch %2d' % (epoch + 1),
              'Train Accuracy: %2.4f %%' % acc_tr,
              'Val Accuracy: %2.4f %%' % acc_v,
              sep='\t\t')

    # Calculate performance on test set
    acc_test, loss_test = evaluate(net, testloader, criterion)
    print('Test Accuracy: %2.4f %%' % acc_test,
          'Test Loss: %2.6f' % loss_test,
          sep='\t\t')