# Exemplo n.º 1  (example marker left by the code-search/scraping tool; kept as a comment so the file parses)
# 0
# NOTE(review): excerpt of a larger training script — model_g1/model_g2/model_f1,
# weight, args, train_loader, start_epoch, CrossEntropyLoss2d, configure,
# emphasize_str, fix_batchnorm_when_training, tqdm and Variable are all
# defined outside this fragment.

# Move the networks and the class-weight tensor to the GPU when available.
if torch.cuda.is_available():
    model_g1.cuda()
    model_g2.cuda()
    model_f1.cuda()
    weight = weight.cuda()

# Per-pixel (2-D) cross-entropy with per-class weights — segmentation loss.
criterion = CrossEntropyLoss2d(weight)

# Start TensorBoard-style event logging into args.tflog_dir.
configure(args.tflog_dir, flush_secs=5)

# Put every sub-network into training mode.
model_g1.train()
model_g2.train()
model_f1.train()
if args.fix_bn:
    # Freeze BatchNorm layers even though the models are in train() mode
    # (presumably to keep running stats fixed, e.g. for fine-tuning — TODO confirm).
    print(emphasize_str("BN layers are NOT trained!"))
    fix_batchnorm_when_training(model_g1)
    fix_batchnorm_when_training(model_g2)
    fix_batchnorm_when_training(model_f1)

    # check_training(model)

# Main training loop; resumable — starts at start_epoch.
for epoch in range(start_epoch, args.epochs):
    epoch_loss = 0
    for ind, (images, labels) in tqdm(enumerate(train_loader)):

        # Wrap the batch in autograd Variables (legacy, pre-0.4 PyTorch API).
        imgs = Variable(images)
        lbls = Variable(labels)
        if torch.cuda.is_available():
            imgs, lbls = imgs.cuda(), lbls.cuda()

        # update generator and classifiers by source samples
        # NOTE(review): the rest of this inner loop body was lost when the
        # snippet was extracted — the fragment ends here.
# NOTE(review): second excerpt — a 3-channel / 1-channel generator pair with
# two classifier heads, iterating over paired (source, target) domain batches.
# All model/helper names come from outside this fragment.

model_g_3ch.train()
model_g_1ch.train()
model_f1.train()
model_f2.train()

if args.no_dropout:
    # Disable dropout layers while keeping the models in train() mode.
    print("NO DROPOUT")
    fix_dropout_when_training(model_g_3ch)
    fix_dropout_when_training(model_g_1ch)
    fix_dropout_when_training(model_f1)
    fix_dropout_when_training(model_f2)

if args.fix_bn:
    # Freeze BatchNorm layers while keeping the models in train() mode.
    print(emphasize_str("BN layers are NOT trained!"))
    fix_batchnorm_when_training(model_g_3ch)
    fix_batchnorm_when_training(model_g_1ch)
    fix_batchnorm_when_training(model_f1)
    fix_batchnorm_when_training(model_f2)

for epoch in range(args.epochs):
    # Per-epoch accumulators: discriminator-style loss and classifier loss.
    d_loss_per_epoch = 0
    c_loss_per_epoch = 0
    # train_loader yields ((src_imgs, src_lbls), (tgt_imgs, ...)) pairs —
    # only the target images are used (unsupervised target domain).
    for ind, (source, target) in tqdm.tqdm(enumerate(train_loader)):
        src_imgs, src_lbls = Variable(source[0]), Variable(source[1])
        tgt_imgs = Variable(target[0])

        if torch.cuda.is_available():
            src_imgs, src_lbls, tgt_imgs = src_imgs.cuda(), src_lbls.cuda(
            ), tgt_imgs.cuda()
        # NOTE(review): the loop body continues past the end of this fragment.
# Exemplo n.º 3  (example marker left by the code-search/scraping tool; kept as a comment so the file parses)
# 0
    # NOTE(review): this fragment begins mid-statement — the two lines below
    # are the tail of a DataLoader kwargs dict (presumably something like
    # `kwargs = {'num_workers': N,` — the opening was lost during extraction).
    'pin_memory': True
} if torch.cuda.is_available() else {}
train_loader = torch.utils.data.DataLoader(src_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           **kwargs)

# Move the single model to the GPU when available.
if torch.cuda.is_available():
    model.cuda()

# Start TensorBoard-style event logging into args.tflog_dir.
configure(args.tflog_dir, flush_secs=5)

model.train()
if args.fix_bn:
    # Freeze BatchNorm layers while keeping the model in train() mode.
    print(emphasize_str("BN layers are NOT trained!"))
    fix_batchnorm_when_training(model)

    # check_training(model)

# Training loop; resumable — starts at start_epoch. Tracks a combined loss
# plus its base and uncertainty components separately.
for epoch in range(start_epoch, args.epochs):
    epoch_loss = 0
    epoch_base_loss = 0
    epoch_uncertain_loss = 0
    for ind, (images, labels) in tqdm(enumerate(train_loader)):

        # Wrap the batch in autograd Variables (legacy, pre-0.4 PyTorch API).
        imgs = Variable(images)
        lbls = Variable(labels)
        if torch.cuda.is_available():
            imgs, lbls = imgs.cuda(), lbls.cuda()

        # update generator and classifiers by source samples
        # NOTE(review): the rest of this loop body is missing from the fragment.
# NOTE(review): fourth excerpt — encoder/decoder pair trained with semantic
# segmentation and boundary losses over (source, target) domain pairs.
if torch.cuda.is_available():
    model_enc.cuda()
    model_dec.cuda()
    weight = weight.cuda()

model_enc.train()
model_dec.train()

if args.no_dropout:
    # Disable dropout layers while keeping the models in train() mode.
    print("NO DROPOUT")
    fix_dropout_when_training(model_enc)
    fix_dropout_when_training(model_dec)

if args.fix_bn:
    # Freeze BatchNorm layers while keeping the models in train() mode.
    print(emphasize_str("BN layers are NOT trained!"))
    fix_batchnorm_when_training(model_enc)
    fix_batchnorm_when_training(model_dec)

for epoch in range(start_epoch, args.epochs):
    # Per-epoch loss accumulators: classifier/discriminator totals plus the
    # individual segmentation/boundary terms ("psuedo" [sic] kept as-is).
    c_loss_per_epoch = 0
    d_loss_per_epoch = 0

    src_semseg_loss_per_epoch = 0
    src_boundary_loss_per_epoch = 0
    tgt_psuedo_boundary_loss_per_epoch = 0
    src_extra_boundary_loss_per_epoch = 0

    for ind, (source, target) in tqdm.tqdm(enumerate(train_loader)):
        # Source batch: images + ground-truth segmentation; target: images only.
        src_imgs, src_gt_semseg = Variable(source[0]), Variable(source[1])
        tgt_imgs = Variable(target[0])
        # NOTE(review): the loop body is truncated here by the extraction.
    # NOTE(review): fragment boundary — these indented .cuda() calls belong to
    # an `if torch.cuda.is_available():` block whose header was cut off above.
    model_f.cuda()
    model_d.cuda()
    weight = weight.cuda()

# Weighted 2-D cross-entropy for segmentation; plain cross-entropy for the
# domain classifier head.
criterion = CrossEntropyLoss2d(weight)
criterion_d = nn.CrossEntropyLoss()

# Visdom/line-plot helper plus TensorBoard-style logging (defined elsewhere).
ploter = LinePlotter()
configure(tflog_dir, flush_secs=5)

# Generator, task classifier and domain discriminator all in training mode.
model_g.train()
model_f.train()
model_d.train()
if args.fix_bn:
    # Freeze BatchNorm layers while keeping the models in train() mode.
    print(emphasize_str("BN layers are NOT trained!"))
    fix_batchnorm_when_training(model_g)
    fix_batchnorm_when_training(model_f)
    fix_batchnorm_when_training(model_d)

# Fixed domain labels for adversarial training: source batches -> 1,
# target batches -> 0 (legacy Variable API).
src_domain_lbl = Variable(torch.ones(args.batch_size).long())
tgt_domain_lbl = Variable(torch.zeros(args.batch_size).long())

for epoch in range(args.start_epoch, args.epochs):
    # Per-epoch accumulators: discriminator loss and classifier loss.
    d_loss_per_epoch = 0
    c_loss_per_epoch = 0

    for ind, (source, target) in tqdm.tqdm(enumerate(train_loader)):
        src_imgs, src_lbls = Variable(source[0]), Variable(source[1])
        tgt_imgs = Variable(target[0])

        # NOTE(review): the fragment ends mid-statement — the cuda-transfer
        # body that followed this `if` was cut off by the extraction.
        if torch.cuda.is_available():