else:
    channels = 3
# The cifar/stl pairing, in either direction, trains a 9-way classifier;
# every other source/target pairing trains a 10-way classifier.
classes = 9 if (args.source, args.target) in (('cifar', 'stl'),
                                               ('stl', 'cifar')) else 10

# Model: a ResNet sized by args.depth / args.width, plus a feature extractor
# taken from it by extractor_from_layer3.
print('==> Building model..')
net = ResNet(args.depth, args.width,
             classes=classes, channels=channels).cuda()
ext = extractor_from_layer3(net)

# Source domain: train split is shuffled, test split is not.
print('==> Preparing datasets..')
sc_tr_dataset, sc_te_dataset = prepare_dataset(
    args.source, image_size, channels, path=args.data_root)
sc_tr_loader = torchdata.DataLoader(
    sc_tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
sc_te_loader = torchdata.DataLoader(
    sc_te_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

# Target domain datasets, prepared with the same image size and channels.
tg_tr_dataset, tg_te_dataset = prepare_dataset(
    args.target, image_size, channels, path=args.data_root)
tg_te_loader = torchdata.DataLoader(tg_te_dataset,
Example #2
0
def _dataset_pair_is(args, first, second):
    """Return True if (args.source, args.target) is (first, second) in either order."""
    return ((args.source == first and args.target == second)
            or (args.source == second and args.target == first))


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Wrap *dataset* in a pinned-memory DataLoader with the given settings."""
    return torchdata.DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=shuffle,
                                num_workers=num_workers,
                                pin_memory=True)


def launch_train(args):
    """Build the model, data loaders and optimizer, then train for args.nepoch epochs.

    Args:
        args: parsed options. Attributes read here: ``source``, ``target``,
            ``depth``, ``width``, ``data_root``, ``batch_size``, ``lr``,
            ``milestone_1``, ``milestone_2``, ``nepoch`` and ``outf``. The
            whole object is also forwarded to ``parse_tasks`` and ``train``.

    Side effects:
        After every epoch, the list of per-epoch stats returned by ``train``
        so far is saved with ``torch.save`` to ``<outf>/loss.pth`` and
        passed to ``plot_all_epoch_stats(..., args.outf)``.
    """
    import os  # only needed to build the checkpoint path portably

    image_size = 32
    num_workers = 1
    # usps<->mnist (either direction) uses 1 input channel; all other pairs use 3.
    channels = 1 if _dataset_pair_is(args, 'usps', 'mnist') else 3
    # cifar<->stl (either direction) uses 9 classes; all other pairs use 10.
    classes = 9 if _dataset_pair_is(args, 'cifar', 'stl') else 10

    print('==> Building model..')
    net = ResNet(args.depth, args.width, classes=classes,
                 channels=channels).cuda()
    ext = extractor_from_layer3(net)

    print('==> Preparing datasets..')
    sc_tr_dataset, sc_te_dataset = prepare_dataset(args.source,
                                                   image_size,
                                                   channels,
                                                   path=args.data_root)
    sc_tr_loader = _make_loader(sc_tr_dataset, args.batch_size,
                                shuffle=True, num_workers=num_workers)
    sc_te_loader = _make_loader(sc_te_dataset, args.batch_size,
                                shuffle=False, num_workers=num_workers)

    tg_tr_dataset, tg_te_dataset = prepare_dataset(args.target,
                                                   image_size,
                                                   channels,
                                                   path=args.data_root)
    # Only the target *test* split gets a loader here; the target train split
    # is handed to parse_tasks directly.
    tg_te_loader = _make_loader(tg_te_dataset, args.batch_size,
                                shuffle=False, num_workers=num_workers)
    sstasks = parse_tasks(args, ext, sc_tr_dataset, sc_te_dataset,
                          tg_tr_dataset, tg_te_dataset)

    criterion = nn.CrossEntropyLoss().cuda()
    # Optimize the main network together with every auxiliary task head.
    parameters = list(net.parameters())
    for sstask in sstasks:
        parameters += list(sstask.head.parameters())
    optimizer = optim.SGD(parameters,
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    # Decay the learning rate by 10x at each of the two milestones.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [args.milestone_1, args.milestone_2],
        gamma=0.1,
        last_epoch=-1)

    loss_path = os.path.join(args.outf, 'loss.pth')
    all_epoch_stats = []
    print('==> Running..')
    for epoch in range(1, args.nepoch + 1):
        print('Source epoch %d/%d lr=%.3f' %
              (epoch, args.nepoch, optimizer.param_groups[0]['lr']))
        print('Error (%)\t\tmmd\ttarget test\tsource test\tunsupervised test')

        epoch_stats = train(args, net, ext, sstasks, criterion, optimizer,
                            scheduler, sc_tr_loader, sc_te_loader,
                            tg_te_loader, epoch)
        all_epoch_stats.append(epoch_stats)
        # Checkpoint after every epoch so partial runs still leave stats on disk.
        torch.save(all_epoch_stats, loss_path)
        plot_all_epoch_stats(all_epoch_stats, args.outf)