Code example #1
import time

import utils

# Helper routines (params_defined, fetch_nonempty_sample, group_sample,
# eval_error, update_model, log_errors, log_elapsed_time, run_validation,
# write_averages) are defined elsewhere in the original module.


def train(model, loss_fn, optimizer, sampler, val_sampler=None, last_iter=0,
          train_writer=None, val_writer=None, monitor=None, **params):
    """Generalized training function."""

    assert params_defined(params), "Params under-specified"

    if monitor is None:
        monitor = utils.LearningMonitor()

    # Determine the names of inputs, labels, and masks
    sample_spec = utils.SampleSpec(sampler().keys())
    mask_names = sample_spec.get_masks()

    print("======= BEGIN TRAINING LOOP ========")
    for i in range(last_iter, params['max_iter']):
        start = time.time()

        # Make sure no mask is empty (data for all tasks)
        sample = fetch_nonempty_sample(sampler, mask_names, params['batch_size'])

        inputs, labels, masks = group_sample(sample, sample_spec, "train")

        # Run the forward pass
        preds = model(*inputs)

        losses, nmsks = eval_error(preds, labels, masks, loss_fn, sample_spec)

        update_model(optimizer, losses)

        log_errors(monitor, losses, nmsks, i)

        # Elapsed time.
        elapsed = time.time() - start
        log_elapsed_time(monitor, elapsed, i, "train")

        if val_sampler is not None and i % params["test_intv"] == 0:
            run_validation(model, val_sampler, params["test_iter"],
                           loss_fn, sample_spec, monitor, val_writer, i)

        if i % params["avgs_intv"] == 0 or i < last_iter + params["warm_up"] - 1:
            monitor.compute_avgs(i, "train")

            # Display stats (both to console and TensorBoard)
            avg_losses = {k: round(monitor.get_last_value(k, "train"), 5)
                          for k in losses.keys()}
            avg_time = round(monitor.get_last_value("iter_time", "train"), 5)

            write_averages(train_writer, avg_losses, avg_time, i)
            print("iter: {}; avg losses = {} (iter_time = {} s on avg)".format(
                i, avg_losses, avg_time))

        if i % params["chkpt_intv"] == 0 and i != last_iter:
            print("SAVE CHECKPOINT: {} iters.".format(i))
            utils.save_chkpt(model, monitor, i, params["model_dir"],
                             params["log_dir"])
Code example #2
import datetime
import time

import torch

# The helpers D (the SimSiam loss term), get_lr, knn_monitor, and save_chkpt
# are defined elsewhere in the original module.


def _train_SimSiam(device, trainloader, bankloader, queryloader,
                   model, optimizer, scheduler, num_epochs, base_dir,
                   saved_epoch=0, loss_hist=None, knn_hist=None, best_acc=80):
    # Use None defaults: mutable default lists would be shared across calls.
    loss_hist = [] if loss_hist is None else loss_hist
    knn_hist = [] if knn_hist is None else knn_hist
    model.to(device)
    scaler = torch.cuda.amp.GradScaler()
    start_time = time.time()
    for epoch in range(saved_epoch, num_epochs):
        model.train()
        epoch_time = time.time()

        for b, (x1, x2, label, _) in enumerate(trainloader):
            x1 = x1.to(device)
            x2 = x2.to(device)

            optimizer.zero_grad()
            # Mixed-precision forward pass; the symmetric SimSiam loss
            # averages the two prediction/target pairings.
            with torch.cuda.amp.autocast():
                z1, z2, p1, p2 = model(x1, x2)
                loss = D(p1, z2) / 2 + D(p2, z1) / 2
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if b % 40 == 0:
                loss_hist.append(loss.item())
                now = datetime.datetime.now()
                print("epoch:{} batch:{} loss:{:.4f} lr:{:.2e} time:{}".format(
                    epoch, b, loss.item(), get_lr(optimizer),
                    now.strftime('%Y-%m-%d %H:%M:%S')))

        # adjust lr by cosine annealing
        scheduler.step()
        correct, total = knn_monitor(device, model, bankloader, queryloader)
        accuracy = correct / total * 100
        knn_hist.append(accuracy)
        print("epoch:{} knn_acc:{:.2f}% ({}/{}) epoch_time:{:.2f}\n".format(
            epoch, accuracy, correct, total, time.time() - epoch_time))
        if accuracy > best_acc:
            # Save a checkpoint whenever the kNN accuracy improves.
            save_path = base_dir + "SimSiam_{}.chkpt".format(int(accuracy * 100))
            save_chkpt(model, optimizer, scheduler, epoch, loss_hist, knn_hist,
                       save_path)
            best_acc = accuracy
            print("---------------- model saved: {:.2f}% ----------------".format(accuracy))
    print("training finished!\ntotal time:{:.2f}, best_acc:{:.2f}".format(time.time()-start_time,best_acc))
Code example #3
File: train.py  Project: agataf/2DUNet
import time

import utils

# Helper routines (params_defined, fetch_nonempty_sample, make_variables,
# eval_error, update_model, log_errors, log_elapsed_time, run_validation)
# are defined elsewhere in the original module.


def train(model,
          loss_fn,
          optimizer,
          sampler,
          val_sampler=None,
          last_iter=0,
          monitor=None,
          **params):
    """Generalized training function."""

    assert params_defined(params), "Params under-specified"

    if monitor is None:
        monitor = utils.LearningMonitor()

    # Determine the names of inputs, labels, and masks
    sample_spec = utils.SampleSpec(sampler.get().keys())
    mask_names = sample_spec.get_masks()

    start = time.time()
    print("======= BEGIN TRAINING LOOP ========")
    for i in range(last_iter, params['max_iter']):

        # Make sure no mask is empty (data for all tasks)
        sample = fetch_nonempty_sample(sampler, mask_names,
                                       params['batch_size'])
        inputs, labels, masks = make_variables(sample, sample_spec, "train")

        # Run the forward pass
        preds = model(*inputs)

        # Resizing of predictions is disabled in this version; the branch
        # below only reports the prediction type for debugging.
        if params["resize"] != 1:
            print("Type of Preds[0]:", type(preds[0]))
            # preds = misc.imresize(preds, 1.0 * params["resize"], interp="bilinear")

        losses, nmsks = eval_error(preds, labels, masks, loss_fn, sample_spec)

        update_model(optimizer, losses)

        log_errors(monitor, losses, nmsks)

        # Elapsed time.
        elapsed = time.time() - start
        log_elapsed_time(monitor, elapsed, "train")
        start = time.time()

        if val_sampler is not None and i % params["test_intv"] == 0:
            run_validation(model, val_sampler, params["test_iter"], loss_fn,
                           sample_spec, monitor, i)
            start = time.time()  #ignore validation time

        if i % params["avgs_intv"] == 0 or i < last_iter + params["warm_up"] - 1:
            monitor.compute_avgs(i, "train")

            # Display stats
            avg_losses = {
                k: round(monitor.get_last_value(k, "train"), 5)
                for k in losses.keys()
            }
            avg_time = round(monitor.get_last_value("iter_time", "train"), 5)
            print("iter: {}; avg losses = {} (iter_time = {} s on avg)".format(
                i, avg_losses, avg_time))

        if i % params["chkpt_intv"] == 0 and i != last_iter:
            print("SAVE CHECKPOINT: {} iters.".format(i))
            utils.save_chkpt(model, monitor, i, params["model_dir"],
                             params["log_dir"])
Code example #4
    # NOTE: this excerpt begins mid-statement, inside the optimizer construction.
                                weight_decay=args.weightDecay)

    # Training loop
    for epoch_num in range(args.initEpochNum,
                           args.initEpochNum + args.nEpochs):
        trn_metrics = runModel(trn_data_gen,
                               model,
                               optimizer,
                               class_wts,
                               'trn',
                               args.batchSize,
                               trn_num_batches,
                               loss_wts=loss_wts)
        utils.log_metrics(epoch_num, trn_metrics, 'trn', log_file,
                          args.savename)
        torch.save(model.state_dict(), args.savename + '.pt')
        val_metrics = runModel(val_data_gen, model, optimizer, class_wts,
                               'val', args.batchSize, val_num_batches, None)
        utils.log_metrics(epoch_num, val_metrics, 'val', log_file,
                          args.savename)
        if best_val_record and val_metrics.AUROC > best_val:
            best_val = utils.save_chkpt(best_val_record, best_val, val_metrics,
                                        model, args.savename)
    tst_metrics = runModel(tst_data_gen, model, optimizer, class_wts, 'tst',
                           args.batchSize, tst_num_batches, None)
    utils.log_metrics(0, tst_metrics, 'tst', log_file, args.savename)
    # val_aggregator = Aggregator('val', task, val_data_loader)
    # val_aggregator.aggregate()
    # tst_aggregator = Aggregator('tst', task, tst_data_loader)
    # tst_aggregator.aggregate()
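
The utils.save_chkpt call above returns the updated best validation score. A minimal sketch matching that call pattern, assuming best_val_record is a log-file path and that the checkpoint filename uses a '_best.pt' suffix (both assumptions; the project's helper is not shown):

import torch

def save_chkpt(best_val_record, best_val, val_metrics, model, savename):
    # Hypothetical sketch: persist the model when validation AUROC improves,
    # append the new best value to a record file, and return it.
    torch.save(model.state_dict(), savename + '_best.pt')
    with open(best_val_record, 'a') as f:
        f.write('best AUROC: {:.4f}\n'.format(val_metrics.AUROC))
    return val_metrics.AUROC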