Example #1
def train(learning_rate, learning_rate_decay, learning_rate_decay_step_size,
          batch_size, num_of_epochs, img_size, arch):
    # check device
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # parameters
    RANDOM_SEED = 42
    N_CLASSES = 3

    # Load Data
    dataset = PoseDataset(csv_file='./labels.csv',
                          img_size=img_size,
                          transform=transforms.ToTensor())

    train_set, test_set = torch.utils.data.random_split(
        dataset,
        [int(np.ceil(0.8 * len(dataset))),
         int(np.floor(0.2 * len(dataset)))])

    train_loader = DataLoader(dataset=train_set,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_set,
                             batch_size=batch_size,
                             shuffle=True)

    # instantiate the model
    torch.manual_seed(RANDOM_SEED)

    if arch == 'simple':
        model = Classifier(N_CLASSES).to(DEVICE)

    elif arch == 'resnet':
        model = ResClassifier(N_CLASSES).to(DEVICE)

    else:
        print('model architecture not supported: use "simple" or "resnet" only')
        return

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=learning_rate_decay_step_size,
                                    gamma=learning_rate_decay)

    cross_entropy_loss_criterion = nn.CrossEntropyLoss()

    print('start training...')
    # start training
    model, optimizer, train_losses, valid_losses = training_loop(
        model, cross_entropy_loss_criterion, batch_size, optimizer, scheduler,
        num_of_epochs, train_loader, test_loader, DEVICE)
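For illustration only (not part of the snippet above), the train function might be invoked like this; the hyperparameter values are arbitrary examples:

if __name__ == '__main__':
    train(learning_rate=1e-3,
          learning_rate_decay=0.1,
          learning_rate_decay_step_size=10,
          batch_size=32,
          num_of_epochs=50,
          img_size=128,
          arch='resnet')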
Example #2
source_loader = torch.utils.data.DataLoader(source_set,
                                            batch_size=args.batch_size,
                                            shuffle=args.shuffle,
                                            num_workers=args.num_workers)
target_loader = torch.utils.data.DataLoader(target_set,
                                            batch_size=args.batch_size,
                                            shuffle=args.shuffle,
                                            num_workers=args.num_workers)

if args.model == 'resnet101':
    netG = ResBase101().cuda()
elif args.model == 'resnet50':
    netG = ResBase50().cuda()
else:
    raise ValueError('Unexpected value of args.model')
netF = ResClassifier(class_num=args.class_num, extract=args.extract).cuda()
netF.apply(weights_init)


def get_L2norm_loss_self_driven(x):
    l = (x.norm(p=2, dim=1).mean() - args.radius)**2
    return args.weight_ring * l
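# Note (added): this is the fixed-radius feature-norm penalty; it pulls the mean L2 norm of the
# batch features toward the target args.radius. For example, with a mean feature norm of 5.0 and
# args.radius = 25.0, the penalty is args.weight_ring * (5.0 - 25.0) ** 2 = 400 * args.weight_ring.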


def get_cls_loss(pred, gt):
    cls_loss = F.nll_loss(F.log_softmax(pred, dim=1), gt)
    return cls_loss


opt_g = optim.SGD(netG.parameters(), lr=args.lr, weight_decay=0.0005)
opt_f = optim.SGD(netF.parameters(),
Example #3
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

t_set = VisDAImage(t_root, t_label, data_transform)
assert len(t_set) == 28978
t_loader = torch.utils.data.DataLoader(t_set,
                                       batch_size=args.batch_size,
                                       shuffle=args.shuffle,
                                       num_workers=args.num_workers)

netG = ResBase50().cuda()
netF = ResClassifier(class_num=args.class_num).cuda()
netG.eval()
netF.eval()

for epoch in range(args.epoch // 2, args.epoch + 1):
    if epoch % 10 != 0:
        continue

    netG.load_state_dict(
        torch.load(
            os.path.join(
                args.snapshot, "VisDA_IAFN_netG_" + args.post + '.' +
                str(args.repeat) + '_' + str(epoch) + ".pth")))
    netF.load_state_dict(
        torch.load(
            os.path.join(
Example #4
File: eval.py  Project: yiyang-wang/AFN
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

t_set = OfficeImage(t_root, t_label, data_transform)
# assert len(t_set) == 795
t_loader = torch.utils.data.DataLoader(t_set,
                                       batch_size=args.batch_size,
                                       shuffle=args.shuffle,
                                       num_workers=args.num_workers)

netG = ResBase50().cuda()
netF = ResClassifier(class_num=args.class_num,
                     extract=False,
                     dropout_p=args.dropout_p).cuda()
netG.eval()
netF.eval()

for epoch in range(args.epoch // 2, args.epoch + 1):
    if epoch % 10 != 0:
        continue
    netG.load_state_dict(
        torch.load(
            os.path.join(
                args.snapshot, "Office31_HAFN_" + args.task + "_netG_" +
                args.post + "." + args.repeat + "_" + str(epoch) + ".pth")))
    netF.load_state_dict(
        torch.load(
            os.path.join(
Example #5
File: eval.py  Project: redhat12345/AFN
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

t_set = OfficeImage(t_root, t_label, data_transform)
assert len(t_set) == get_dataset_length(args.target + '_shared')

t_loader = torch.utils.data.DataLoader(t_set,
                                       batch_size=args.batch_size,
                                       shuffle=args.shuffle,
                                       num_workers=args.num_workers)

netG = ResBase50().cuda()
netF = ResClassifier(class_num=args.class_num, extract=False).cuda()
netG.eval()
netF.eval()

for epoch in range(1, args.epoch + 1):
    if epoch % 10 != 0:
        continue
    netG.load_state_dict(
        torch.load(
            os.path.join(
                args.snapshot, "OfficeHome_IAFN_" + args.task + "_netG_" +
                args.post + '.' + args.repeat + '_' + str(epoch) + ".pth")))
    netF.load_state_dict(
        torch.load(
            os.path.join(
                args.snapshot, "OfficeHome_IAFN_" + args.task + "_netF_" +
Example #6
File: train.py  Project: yiyang-wang/AFN
train_transform = transforms.Compose([
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

source_set = OfficeImage(source_root, source_label, train_transform)
target_set = OfficeImage(target_root, target_label, train_transform)

source_loader = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size,
    shuffle=args.shuffle, num_workers=args.num_workers)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size,
    shuffle=args.shuffle, num_workers=args.num_workers)

netG = ResBase50().cuda()
netF = ResClassifier(class_num=args.class_num, extract=args.extract, dropout_p=args.dropout_p).cuda()
netF.apply(weights_init)


def get_cls_loss(pred, gt):
    cls_loss = F.nll_loss(F.log_softmax(pred, dim=1), gt)
    return cls_loss

def get_L2norm_loss_self_driven(x):
    l = (x.norm(p=2, dim=1).mean() - args.radius) ** 2
    return args.weight_L2norm * l

opt_g = optim.SGD(netG.parameters(), lr=args.lr, weight_decay=0.0005)
opt_f = optim.SGD(netF.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0005)
                    
for epoch in range(1, args.pre_epoches + 1):
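The snippet stops at the head of the pre-training loop, so its body is not shown. As a rough standalone sketch only (not the original train.py; it assumes source_loader yields (image, label) batches and that netF returns class logits here), a classification-only pre-training step could look like the following; where and on which features get_L2norm_loss_self_driven is applied depends on the original loop and args.extract:

for img, label in source_loader:
    img, label = img.cuda(), label.cuda()
    opt_g.zero_grad()
    opt_f.zero_grad()
    feature = netG(img)
    pred = netF(feature)
    loss = get_cls_loss(pred, label)
    loss.backward()
    opt_g.step()
    opt_f.step()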
Example #7
File: train.py  Project: redhat12345/AFN
source_set = VisDAImage(source_root, source_label, train_transform)
target_set = VisDAImage(target_root, target_label, train_transform)
assert len(source_set) == 152397
assert len(target_set) == 55388
source_loader = torch.utils.data.DataLoader(source_set, batch_size=args.batch_size,
    shuffle=args.shuffle, num_workers=args.num_workers)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size,
    shuffle=args.shuffle, num_workers=args.num_workers)

if args.model == 'resnet101':
    netG = ResBase101().cuda()
elif args.model == 'resnet50':
    netG = ResBase50().cuda()
else:
    raise ValueError('Unexpected value of args.model')
    
netF = ResClassifier(class_num=args.class_num, extract=args.extract).cuda()
netF.apply(weights_init)


def get_cls_loss(pred, gt):
    cls_loss = F.nll_loss(F.log_softmax(pred, dim=1), gt)
    return cls_loss

def get_L2norm_loss_self_driven(x):
    radius = x.norm(p=2, dim=1).detach()
    assert not radius.requires_grad
    radius = radius + 0.3
    l = ((x.norm(p=2, dim=1) - radius) ** 2).mean()
    return args.weight_L2norm * l
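# Note (added): unlike the fixed-radius penalty in the earlier examples (args.radius), this variant
# detaches the current per-sample feature norms and targets them plus a step of 0.3, so each update
# encourages every feature's L2 norm to grow by roughly that step; args.weight_L2norm scales the term.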