示例#1
0
def _extract_single_view_features(model, loader, device, view):
    """Embed one modality's loader and return ``(features, labels)`` as NumPy arrays.

    The model is invoked with a single keyword argument, ``model(**{view: x})``
    (``view`` is ``'cartoons'`` or ``'portraits'`` in ``train_model``), and the
    first element of the 2-tuple it returns is taken as the feature. Labels
    are squeezed along their last axis. No autograd history is built.
    ``np.concatenate`` raises ``ValueError`` if ``loader`` yields nothing.

    Args:
        model: network supporting the single-view call described above.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.
        device: device the inputs are moved to before the forward pass.
        view: keyword name under which the inputs are passed to ``model``.

    Returns:
        Tuple ``(features, labels)``, each concatenated over all batches
        along axis 0.
    """
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            features, _ = model(**{view: inputs.to(device)})
            feature_batches.append(features.cpu().numpy())
            # Labels are only consumed on the host by the metric helpers, so
            # they are not round-tripped through ``device``.
            label_batches.append(labels.cpu().squeeze(-1).numpy())
    return np.concatenate(feature_batches), np.concatenate(label_batches)


def train_model(model,
                dataloader,
                cartoon_dataloader,
                portrait_dataloader,
                hyper_parameters,
                optimizer,
                scheduler=None,
                device="cpu",
                num_epochs=500,
                save_name='best.pt'):
    """Train a cartoon/portrait retrieval model and keep the best-mAP weights.

    Each epoch runs a ``'train'`` phase and a ``'valid'`` phase over
    ``dataloader[phase]``, which yields
    ``(cartoons, cartoon_labels, portraits, portrait_labels)`` batches scored
    with ``calc_loss``. After the validation loss pass, the cartoon and
    portrait sets in ``cartoon_dataloader['valid']`` and
    ``portrait_dataloader['valid']`` are embedded separately. They are then
    scored with ``fx_calc_map_label`` (mAP) and ``fx_calc_recall`` (printed as
    R1/R5/R10). The weights with the highest validation mAP are saved to
    ``weights/<save_name>`` and loaded back into ``model`` before returning.

    Args:
        model: network. ``model(cartoons, portraits)`` must return
            ``(cartoon_feat, portrait_feat, cartoon_pred, portrait_pred)``.
            A single-view call (``model(cartoons=...)`` or
            ``model(portraits=...)``) must return a 2-tuple whose first
            element is the feature.
        dataloader: dict of paired loaders keyed by ``'train'`` and ``'valid'``.
        cartoon_dataloader: dict with at least a ``'valid'`` loader yielding
            ``(cartoons, labels)``.
        portrait_dataloader: dict with at least a ``'valid'`` loader yielding
            ``(portraits, labels)``.
        hyper_parameters: forwarded unchanged to ``calc_loss``.
        optimizer: optimizer stepped once per training batch.
        scheduler: optional. When given, ``scheduler.step()`` is called after
            every training batch, not once per epoch.
        device: device the input tensors are moved to.
        num_epochs: number of epochs to run.
        save_name: checkpoint file name written under ``weights/``.

    Returns:
        ``(model, valid_map_history, valid_loss_history)``. Each history holds
        one entry per epoch.
    """
    since = time.time()
    test_sketch_acc_history = []
    epoch_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_recall = [0.0, 0.0, 0.0]

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 20)

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            # train(False) is equivalent to eval(): freezes dropout/batch-norm.
            model.train(is_train)

            running_loss = 0.0
            for cartoons, cartoon_labels, portraits, portrait_labels in dataloader[phase]:
                cartoons = cartoons.to(device)
                cartoon_labels = cartoon_labels.to(device)
                portraits = portraits.to(device)
                portrait_labels = portrait_labels.to(device)

                optimizer.zero_grad()
                # Only build the autograd graph while training.
                with torch.set_grad_enabled(is_train):
                    (cartoons_feature, portraits_feature,
                     cartoons_predict, portraits_predict) = model(cartoons, portraits)
                    loss = calc_loss(cartoons_feature, portraits_feature,
                                     cartoons_predict, portraits_predict,
                                     cartoon_labels, portrait_labels,
                                     hyper_parameters)

                    if is_train:
                        loss.backward()
                        optimizer.step()
                        if scheduler is not None:
                            scheduler.step()

                running_loss += loss.item()

            # Normalise by the dataset that was actually iterated above; the
            # max() guards against an empty dataset.
            epoch_loss = running_loss / max(len(dataloader[phase].dataset), 1)
            if is_train:
                print('{} Loss: {:.4f}'.format(phase, epoch_loss))
                continue

            # Validation: embed each modality separately, then score retrieval.
            t_cartoon_features, t_cartoon_labels = _extract_single_view_features(
                model, cartoon_dataloader[phase], device, 'cartoons')
            t_portrait_features, t_portrait_labels = _extract_single_view_features(
                model, portrait_dataloader[phase], device, 'portraits')

            cartoon2real_map = fx_calc_map_label(t_cartoon_features,
                                                 t_cartoon_labels,
                                                 t_portrait_features,
                                                 t_portrait_labels)
            cartoon2real_recall = fx_calc_recall(t_cartoon_features,
                                                 t_cartoon_labels,
                                                 t_portrait_features,
                                                 t_portrait_labels)

            print(
                '{} Loss: {:.4f} Cartoon2Real: mAP = {:.4f} R1 = {:.4f} R5 = {:.4f} R10 = {:.4f}'
                .format(phase, epoch_loss, cartoon2real_map,
                        cartoon2real_recall[0], cartoon2real_recall[1],
                        cartoon2real_recall[2]))

            # Keep a deep copy of the best weights; the live state_dict would
            # keep changing as training continues.
            if cartoon2real_map > best_acc:
                best_acc = cartoon2real_map
                best_recall = cartoon2real_recall
                best_model_wts = copy.deepcopy(model.state_dict())
            test_sketch_acc_history.append(cartoon2real_map)
            epoch_loss_history.append(epoch_loss)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best average ACC: {:.4f}'.format(best_acc))
    print('Best recall: R1 = {:.4f} R5 = {:.4f} R10 = {:.4f}'.format(
        best_recall[0], best_recall[1], best_recall[2]))

    save_folder = 'weights'
    os.makedirs(save_folder, exist_ok=True)
    torch.save(best_model_wts, os.path.join(save_folder, save_name))

    # Restore the best weights so the returned model matches the checkpoint.
    model.load_state_dict(best_model_wts)
    return model, test_sketch_acc_history, epoch_loss_history
示例#2
0
                cartoons_feature, _ = model(cartoons=cartoons)
                t_cartoon_features.append(cartoons_feature.cpu().numpy())
                t_cartoon_names.append(cartoon_names.cpu().squeeze(-1).numpy())

            for portraits, portrait_names in portrait_dataloader['valid']:
                portraits = portraits.to(device)
                portrait_names = portrait_names.to(device)
                #portrait_names = np.asarray(portrait_names)
                portraits_feature, _ = model(portraits=portraits)
                t_portrait_features.append(portraits_feature.cpu().numpy())
                t_portrait_names.append(
                    portrait_names.cpu().squeeze(-1).numpy())

        t_cartoon_features = np.concatenate(t_cartoon_features)
        t_cartoon_names = np.concatenate(t_cartoon_names)
        t_portrait_features = np.concatenate(t_portrait_features)
        t_portrait_names = np.concatenate(t_portrait_names)

        Sketch2Video_map = fx_calc_map_label(t_cartoon_features,
                                             t_cartoon_names,
                                             t_portrait_features,
                                             t_portrait_names)
        Sketch2Video = fx_calc_recall(t_cartoon_features, t_cartoon_names,
                                      t_portrait_features, t_portrait_names)
        print(
            'Sketch2Video: mAP = {:.4f} R1 = {:.4f} R5 = {:.4f} R10 = {:.4f}'.
            format(Sketch2Video_map, Sketch2Video[0], Sketch2Video[1],
                   Sketch2Video[2]))

        print('...Validating is completed...')
示例#3
0
文件: main.py 项目: daisystar/ALGCN
    # params_to_update = list(model_ft.parameters())
    params_to_update = model_ft.get_config_optim(lr)

    # Observe that all parameters are being optimized
    optimizer = optim.Adam(params_to_update, lr=lr, betas=betas)
    if EVAL:
        model_ft.load_state_dict(torch.load('model/ALGCN_' + dataset + '.pth'))
    else:
        print('...Training is beginning...')
        # Train and evaluate
        model_ft, img_acc_hist, txt_acc_hist, loss_hist = train_model(
            model_ft, data_loader, optimizer, alpha, MAX_EPOCH)
        print('...Training is completed...')

        torch.save(model_ft.state_dict(), 'model/ALGCN_' + dataset + '.pth')

    print('...Evaluation on testing data...')
    view1_feature, view2_feature, view1_predict, view2_predict, classifiers = model_ft(
        torch.tensor(input_data_par['img_test']).cuda(),
        torch.tensor(input_data_par['text_test']).cuda())
    label = input_data_par['label_test']
    view1_feature = view1_feature.detach().cpu().numpy()
    view2_feature = view2_feature.detach().cpu().numpy()
    img_to_txt = fx_calc_map_label(view1_feature, view2_feature, label)
    print('...Image to Text MAP = {}'.format(img_to_txt))

    txt_to_img = fx_calc_map_label(view2_feature, view1_feature, label)
    print('...Text to Image MAP = {}'.format(txt_to_img))

    print('...Average MAP = {}'.format(((img_to_txt + txt_to_img) / 2.)))
示例#4
0
def _evaluate_img_txt_retrieval(model, loader):
    """Embed every ``(imgs, txts, labels)`` batch and return ``(img2txt, txt2img)`` mAP.

    Inputs are moved to CUDA when it is available, matching ``train_model``.
    ``model(imgs, txts)`` must return a 4-tuple whose first two elements are
    the image and text features. Labels are collapsed with ``argmax`` over
    axis 1, so they must be at least 2-D (presumably one-hot; confirm against
    the dataset). No autograd history is built. The caller is responsible for
    putting the model in eval mode.
    """
    t_imgs, t_txts, t_labels = [], [], []
    with torch.no_grad():
        for imgs, txts, labels in loader:
            if torch.cuda.is_available():
                imgs = imgs.cuda()
                txts = txts.cuda()
            view1_feature, view2_feature, _, _ = model(imgs, txts)
            t_imgs.append(view1_feature.cpu().numpy())
            t_txts.append(view2_feature.cpu().numpy())
            # Labels are only needed on the host for the metric.
            t_labels.append(labels.cpu().numpy())
    t_imgs = np.concatenate(t_imgs)
    t_txts = np.concatenate(t_txts)
    t_labels = np.concatenate(t_labels).argmax(1)
    return (fx_calc_map_label(t_imgs, t_txts, t_labels),
            fx_calc_map_label(t_txts, t_imgs, t_labels))


def train_model(model, data_loaders, optimizer, alpha, beta, device="cpu", num_epochs=500):
    """Train an image/text model, keeping the weights with the best mean test mAP.

    Each epoch runs a ``'train'`` phase and a ``'test'`` phase over
    ``data_loaders[phase]``, which yields ``(imgs, txts, labels)`` batches
    scored with ``calc_loss``. Retrieval mAP is computed only in the test
    phase, after ``model.eval()``. Evaluating right after the train phase
    would run with the model still in training mode: dropout would be active
    and BatchNorm running statistics would be updated from test data.

    Args:
        model: network where ``model(imgs, txts)`` returns
            ``(img_feat, txt_feat, img_pred, txt_pred)``.
        data_loaders: dict with ``'train'`` and ``'test'`` loaders.
        optimizer: optimizer stepped once per training batch.
        alpha: forwarded to ``calc_loss``.
        beta: forwarded to ``calc_loss``.
        device: accepted for interface compatibility but not used. Inputs
            are moved to CUDA whenever ``torch.cuda.is_available()``.
        num_epochs: number of epochs to run.

    Returns:
        ``(model, img2txt_map_history, txt2img_map_history, test_loss_history)``,
        with one history entry per epoch. ``model`` has the best weights loaded.
    """
    since = time.time()
    test_img_acc_history = []
    test_txt_acc_history = []
    epoch_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    use_cuda = torch.cuda.is_available()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 20)

        for phase in ['train', 'test']:
            is_train = phase == 'train'
            # train(False) is equivalent to eval().
            model.train(is_train)

            running_loss = 0.0
            for imgs, txts, labels in data_loaders[phase]:
                # NaN is the only value that compares unequal to itself; any()
                # reports even a single NaN.
                if (imgs != imgs).any() or (txts != txts).any():
                    print("Data contains Nan.")

                if use_cuda:
                    imgs = imgs.cuda()
                    txts = txts.cuda()
                    labels = labels.cuda()

                optimizer.zero_grad()
                # Only build the autograd graph while training.
                with torch.set_grad_enabled(is_train):
                    view1_feature, view2_feature, view1_predict, view2_predict = model(imgs, txts)
                    loss = calc_loss(view1_feature, view2_feature, view1_predict,
                                     view2_predict, labels, labels, alpha, beta)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item()

            # max() guards against an empty dataset.
            epoch_loss = running_loss / max(len(data_loaders[phase].dataset), 1)
            if is_train:
                print('{} Loss: {:.4f}'.format(phase, epoch_loss))
                continue

            img2text, txt2img = _evaluate_img_txt_retrieval(model, data_loaders['test'])
            print('{} Loss: {:.4f} Img2Txt: {:.4f}  Txt2Img: {:.4f}'.format(phase, epoch_loss, img2text, txt2img))

            # Keep a deep copy of the best weights; the live state_dict would
            # keep changing as training continues.
            mean_map = (img2text + txt2img) / 2.
            if mean_map > best_acc:
                best_acc = mean_map
                best_model_wts = copy.deepcopy(model.state_dict())
            test_img_acc_history.append(img2text)
            test_txt_acc_history.append(txt2img)
            epoch_loss_history.append(epoch_loss)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best average ACC: {:.4f}'.format(best_acc))

    # Restore the best weights before returning.
    model.load_state_dict(best_model_wts)
    return model, test_img_acc_history, test_txt_acc_history, epoch_loss_history