Example #1
def train(params, log):
    # specify dataset
    data = DatasetFactory.create(params)

    # specify model
    model = ModelFactory.create(params)

    # define loss function (criterion)
    criterion = set_loss_function(params)

    # optimizer & scheduler & load from checkpoint
    optimizer, scheduler, start_epoch, best_prec = set_optimizer_scheduler(params, model, log)

    # log details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log_string += "\n" + "==== SCHEDULER:\n" + str(scheduler) + "\n"
    log_string += "\n" + "==== DATASET (TRAIN):\n" + str(data.dataset['train']) + "\n"
    log_string += "\n" + "==== DATASET (VAL):\n" + repr(data.dataset['val']) + "\n"
    log.log_global(log_string)

    # train
    for epoch in range(start_epoch, params['TRAIN']['epochs']):

        # train for one epoch
        _, _ = train_epoch(data.loader['train'], model, criterion, optimizer, scheduler, epoch,
                           params['device'], log)

        # evaluate on train set
        acc_train, loss_train = validate(data.loader['train'], model, criterion, params['device'])

        # evaluate on validation set
        acc_val, loss_val = validate(data.loader['val'], model, criterion, params['device'])

        # remember best prec@1
        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)

        # save checkpoint
        if params['LOG']['do_checkpoint']:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler
            }, model, params, is_best)

        # logging results
        time_string = log.timers['global'].current2str()  # get current time
        log.log_epoch(epoch + 1,
                      acc_train, loss_train,
                      acc_val, loss_val,
                      is_best, time_string)
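
All four examples score the model with the same validate(loader, model, criterion, device) helper, whose body is not shown. Below is a minimal sketch inferred from the call sites; returning top-1 accuracy in percent and assuming a batch-mean criterion are assumptions, not taken from the source:

import torch

def validate(loader, model, criterion, device):
    """Hypothetical helper: returns (top-1 accuracy in %, mean loss)."""
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # criterion is assumed to return the batch mean (e.g. CrossEntropyLoss)
            loss_sum += criterion(outputs, targets).item() * targets.size(0)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return 100.0 * correct / total, loss_sum / total
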
Example #2
import torch
from torch import nn

def train(params, log, time_keeper):
    # specify dataset
    data = DatasetFactory.create(params)

    # specify model
    model = ModelFactory.create(params)
    model = model.to(params['device'])

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(params['device'])
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=params['TRAIN']['lr'],
                                momentum=params['TRAIN']['momentum'])

    # resume from a checkpoint
    if len(params['TRAIN']['resume']) > 0:
        start_epoch, best_prec = load_checkpoint(log, model,
                                                 params['TRAIN']['resume'],
                                                 optimizer)
    else:
        start_epoch = 0
        best_prec = 0

    # scheduler (if any)
    if 'lr_schedule_step' in params['TRAIN']:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=params['TRAIN']['lr_schedule_step'],
            gamma=params['TRAIN']['lr_schedule_gamma'])
    else:
        scheduler = None

    # log details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log.log_global(log_string)

    time_keeper.start()

    # train
    for epoch in range(start_epoch, params['TRAIN']['epochs']):
        # adjust_learning_rate
        if scheduler:
            scheduler.step()

        # train for one epoch
        _, _ = train_epoch(data.loader['train'], model, criterion, optimizer,
                           epoch, params['device'], log, time_keeper)

        # evaluate on train set
        acc_train, loss_train = validate(data.loader['train'], model,
                                         criterion, params['device'])

        # evaluate on validation set
        acc_val, loss_val = validate(data.loader['val'], model, criterion,
                                     params['device'])

        # remember best prec@1 and save checkpoint
        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler
            }, model, params, is_best)

        # logging results
        time_string = time_keeper.get_current_str()  # get current time
        log.log_epoch(epoch + 1, acc_train, loss_train, acc_val, loss_val,
                      is_best, time_string)
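
Examples #2, #3 and #4 resume via a load_checkpoint()/save_checkpoint() pair whose bodies are not shown. A possible implementation, sketched from the call sites; the checkpoint directory key params['LOG']['dir'] and the file names are assumptions:

import os
import shutil
import torch

def save_checkpoint(state, model, params, is_best):
    """Hypothetical helper: persist `state`, keep a copy of the best model.

    `model` is unused in this sketch; the real helper may derive file names from it.
    """
    ckpt_dir = params['LOG']['dir']  # assumed config key
    path = os.path.join(ckpt_dir, 'checkpoint.pth.tar')
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(ckpt_dir, 'model_best.pth.tar'))

def load_checkpoint(log, model, checkpoint_file, optimizer, scheduler=None):
    """Hypothetical helper: returns (start_epoch, best_prec)."""
    if not checkpoint_file or not os.path.isfile(checkpoint_file):
        return 0, 0  # nothing to resume from
    state = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    log.log_global("resumed from %s (epoch %d)" % (checkpoint_file, state['epoch']))
    return state['epoch'], state['best_prec']
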
Example #3
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

def train(params, log, time_keeper, tboard_exp_path):
    writer = SummaryWriter(tboard_exp_path)

    # specify dataset
    data = DatasetFactory.create(params)

    # specify model
    model = ModelFactory.create(params)
    model = model.to(params['device'])

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(params['device'])

    if params['MODEL']['name'] == 'resnet18':
        optimizer = torch.optim.SGD(model.fc.parameters(),
                                    lr=params['TRAIN']['lr'],
                                    momentum=params['TRAIN']['momentum'])
    elif params['MODEL']['name'] == 'vgg19':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=params['TRAIN']['lr'],
                                    momentum=params['TRAIN']['momentum'])
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=params['TRAIN']['lr'],
                                    momentum=params['TRAIN']['momentum'])

    # draft scheduler for an experiment; note that it steps the same optimizer
    # as the config-driven `scheduler` created below
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=6,
                                                       gamma=0.1)

    # resume from a checkpoint
    if len(params['TRAIN']['resume']) > 0:
        start_epoch, best_prec = load_checkpoint(log, model,
                                                 params['TRAIN']['resume'],
                                                 optimizer)
    else:
        start_epoch = 0
        best_prec = 0

    # scheduler (if any)
    if 'lr_schedule_step' in params['TRAIN']:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=params['TRAIN']['lr_schedule_step'],
            gamma=params['TRAIN']['lr_schedule_gamma'])
    else:
        scheduler = None

    # log details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log.log_global(log_string)

    time_keeper.start()

    # train
    for epoch in range(start_epoch, params['TRAIN']['epochs']):
        # adjust_learning_rate
        if scheduler:
            scheduler.step()

        # train for one epoch
        _, _ = train_epoch(data.loader['train'], model, criterion, optimizer,
                           epoch, params['device'], log, time_keeper, writer,
                           exp_lr_scheduler)

        # evaluate on train set
        acc_train, loss_train = validate(data.loader['train'], model,
                                         criterion, params['device'])

        # evaluate on validation set
        acc_val, loss_val = validate(data.loader['val'], model, criterion,
                                     params['device'])

        # per-class accuracy on the train set; `classes` is assumed to be a
        # module-level list of class names
        correct = 0
        total = 0

        class_correct = [0.0] * len(classes)
        class_total = [0.0] * len(classes)

        with torch.no_grad():
            for (input_data, target) in data.loader['train']:
                images = input_data.to(params['device'])
                labels = target.to(params['device'])
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                c = (predicted == labels)
                total += labels.size(0)
                correct += c.sum().item()
                for j in range(labels.size(0)):  # iterate over the whole batch
                    label = labels[j]
                    class_correct[label] += c[j].item()
                    class_total[label] += 1

        print('Accuracy of the network on epoch %d train images: %d %%' %
              (epoch, 100 * correct / total))
        for k in range(len(classes)):
            acc_k = 100 * class_correct[k] / max(class_total[k], 1)
            print('Accuracy of %5s : %2d %%' % (classes[k], acc_k))
            writer.add_scalar('Accuracy of %5s' % classes[k], acc_k, epoch + 1)

        # remember best prec@1 and save checkpoint
        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler
            }, model, params, is_best)

        # logging results
        time_string = time_keeper.get_current_str()  # get current time
        log.log_epoch(epoch + 1, acc_train, loss_train, acc_val, loss_val,
                      is_best, time_string)
        writer.add_scalar("Train accuracy ", acc_train, epoch + 1)
        writer.add_scalar("Train Loss ", loss_train, epoch + 1)
        writer.add_scalar("Test accuracy ", acc_val, epoch + 1)
        writer.add_scalar("Test Loss ", loss_val, epoch + 1)
        exp_lr_scheduler.step()
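
A note on scheduler placement: examples #2 and #3 call scheduler.step() at the top of the epoch loop, which was the convention before PyTorch 1.1. Since PyTorch 1.1 the scheduler should step after the optimizer, i.e. at the end of the epoch. A minimal self-contained sketch of the recommended ordering (the model and optimizer here are placeholders):

import torch
from torch import nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

for epoch in range(20):
    # ... one epoch of training: forward, loss.backward(), optimizer.step() ...
    optimizer.step()   # stand-in for the epoch's parameter updates
    scheduler.step()   # step the LR schedule after the optimizer, once per epoch
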
Example #4
import torch.nn.functional as F
from torch.optim import SGD
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
# CategoricalAccuracy comes from older ignite releases (later replaced by Accuracy)
from ignite.metrics import CategoricalAccuracy, Loss

def train(params, log, time_keeper):
    # specify dataset
    dataset = DatasetFactory.create(params)

    # specify model
    model = ModelFactory.create(params)
    model = model.to(params['device'])

    # optimizer
    optimizer = SGD(model.parameters(),
                    lr=params['TRAIN']['lr'],
                    momentum=params['TRAIN']['momentum'])

    # scheduler
    scheduler = None

    # best accuracy (precision)
    best_prec = 0

    # optionally resume from a checkpoint
    checkpoint_file = params['TRAIN']['resume']
    start_epoch, best_prec = load_checkpoint(log, model, checkpoint_file,
                                             optimizer, scheduler)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.cross_entropy,
                                        device=params['device'])

    # evaluator
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': CategoricalAccuracy(),
                                                'cross_entropy': Loss(F.cross_entropy)
                                            },
                                            device=params['device'])
    # log details
    log_string = "\n" + "==== NET MODEL:\n" + str(model)
    log_string += "\n" + "==== OPTIMIZER:\n" + str(optimizer) + "\n"
    log.log_global(log_string)

    # end-of-iteration events
    @trainer.on(Events.ITERATION_COMPLETED)
    def on_iter(engine):
        iter_current = engine.state.iteration % len(dataset.loader['train'])
        epoch_current = engine.state.epoch
        num_iter = len(dataset.loader['train'])
        loss = engine.state.output

        # logging
        time_string = time_keeper.get_current_str()  # get current time
        log.log_iter(iter_current, epoch_current - 1, num_iter, loss,
                     time_string)

    # end-of-epoch events
    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch(engine):
        nonlocal best_prec

        # current epoch
        epoch_current = engine.state.epoch

        # evaluation on train set
        evaluator.run(dataset.loader['train'])
        acc_train = evaluator.state.metrics['accuracy'] * 100
        loss_train = evaluator.state.metrics['cross_entropy']

        # evaluation on val set
        evaluator.run(dataset.loader['val'])
        acc_val = evaluator.state.metrics['accuracy'] * 100
        loss_val = evaluator.state.metrics['cross_entropy']

        is_best = acc_val > best_prec
        best_prec = max(acc_val, best_prec)
        save_checkpoint(
            {
                'epoch': epoch_current + 1,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler
            }, model, params, is_best)

        # logging results
        time_string = time_keeper.get_current_str()  # get current time
        log.log_epoch(epoch_current, acc_train, loss_train, acc_val, loss_val,
                      is_best, time_string)

    time_keeper.start()
    trainer.run(dataset.loader['train'], max_epochs=params['TRAIN']['epochs'])
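
All four train() variants read the same nested params dict. A minimal sketch of the keys they touch; the values are illustrative placeholders, not taken from the source:

params = {
    'device': 'cuda:0',
    'MODEL': {'name': 'resnet18'},
    'TRAIN': {
        'lr': 0.01,
        'momentum': 0.9,
        'epochs': 30,
        'resume': '',              # path to a checkpoint; '' trains from scratch
        'lr_schedule_step': 6,     # optional: presence enables the StepLR branch
        'lr_schedule_gamma': 0.1,
    },
    'LOG': {
        'do_checkpoint': True,
        'dir': './runs',           # assumed only by the checkpoint sketch above
    },
}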