Пример #1
0
def train_model(data_loader,
                feature_model,
                model,
                criterion,
                optimizer,
                lr_scheduler,
                num_epochs=25,
                device=None):
    """Train a regression head on top of backbone features.

    Args:
        data_loader: yields ``(inputs, targets)`` batches.
        feature_model: backbone exposing a ``features`` method.
        model: trainable head applied to the flattened features.
        criterion: loss callable.
        optimizer: optimizer for ``model``'s parameters.
        lr_scheduler: scheduler stepped once per batch.
        num_epochs: number of passes over the data.
        device: torch device the batches are moved to (may be None).

    Returns:
        List of per-epoch average losses.
    """
    since = time.time()

    model.train()  # Set model to training mode
    loss_list = list()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        running_loss = 0.0

        # Iterate over data.
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.float().to(device)

            # NOTE(review): backbone gradients are not explicitly frozen here;
            # assumes feature_model params are excluded from the optimizer.
            features = feature_model.features(inputs)
            features = torch.flatten(features, 1)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(features)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            # statistics, weighted by batch size for a correct epoch mean
            running_loss += loss.item() * inputs.size(0)
            lr_scheduler.step()

        # len() is the idiomatic spelling of dataset.__len__()
        epoch_loss = running_loss / len(data_loader.dataset)
        loss_list.append(epoch_loss)

        print('{} Loss: {:.4f}'.format(epoch, epoch_loss))

        # Save a checkpoint after every epoch.
        util.save_model(model, './models/bbox_regression_%d.pth' % epoch)

    print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return loss_list
Пример #2
0
def main():
    """Train a BERT bi-encoder (mention/candidate) for entity linking.

    Builds tokenizer and datasets, optionally restores pre-trained biencoder
    weights, trains until completion or Ctrl-C, then saves the model.
    """
    args, logger = parse_args()

    # Dedicated special tokens mark mention span boundaries in the input text.
    mention_tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    mention_tokenizer.add_special_tokens({"additional_special_tokens": ["[M]", "[/M]"]})

    index = np.load(args.mention_index)
    mention_dataset = MentionDataset(args.mention_dataset, index, mention_tokenizer, preprocessed=args.mention_preprocessed, return_json=True)
    #mention_dataset = MentionDataset2(args.mention_dataset, mention_tokenizer, preprocessed=args.mention_preprocessed)
    candidate_dataset = CandidateDataset(args.candidate_dataset, mention_tokenizer, preprocessed=args.candidate_preprocessed)

    mention_bert = AutoModel.from_pretrained(args.model_name)
    # Embedding table must grow to cover the newly added special tokens.
    mention_bert.resize_token_embeddings(len(mention_tokenizer))
    candidate_bert = AutoModel.from_pretrained(args.model_name)

    biencoder = BertBiEncoder(mention_bert, candidate_bert)

    if args.load_model_path:
        biencoder.load_state_dict(torch.load(args.load_model_path))

    # NOTE(review): `device` is not defined in this function — presumably a
    # module-level global; verify against the surrounding file.
    model = BertCandidateGenerator(biencoder, device, model_path=args.model_path, use_mlflow=args.mlflow, logger=logger)

    try:
        model.train(
            mention_dataset,
            candidate_dataset,
            inbatch=args.inbatch,
            lr=args.lr,
            batch_size=args.bsz,
            random_bsz=args.random_bsz,
            max_ctxt_len=args.max_ctxt_len,
            max_title_len=args.max_title_len,
            max_desc_len=args.max_desc_len,
            traindata_size=args.traindata_size,
            model_save_interval=args.model_save_interval,
            grad_acc_step=args.gradient_accumulation_steps,
            max_grad_norm=args.max_grad_norm,
            epochs=args.epochs,
            warmup_propotion=args.warmup_propotion,
            fp16=args.fp16,
            fp16_opt_level=args.fp16_opt_level,
            parallel=args.parallel,
            hard_negative=args.hard_negative
        )
    except KeyboardInterrupt:
        # Allow manual interruption; the model trained so far is still saved below.
        pass

    save_model(model.model, args.model_path)
    #torch.save(model.model.state_dict(), args.model_path)

    if args.mlflow:
        mlflow.end_run()
Пример #3
0
def main():
    """Entry point: seed all RNGs, build model/dataloaders, and run training."""
    args = get_arguments()
    # Fix every RNG source for reproducibility.
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)
    if (args.cuda):
        torch.cuda.manual_seed(SEED)
    model, optimizer, training_generator, val_generator, test_generator = initialize(
        args)

    print(model)

    best_pred_loss = 1000.0
    # Halve the learning rate when validation loss plateaus for 2 epochs.
    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.5,
                                  patience=2,
                                  min_lr=1e-5,
                                  verbose=True)
    print('Checkpoint folder ', args.save)
    if args.tensorboard:
        writer = SummaryWriter('./runs/' + util.datestr())
    else:
        writer = None
    for epoch in range(1, args.nEpochs + 1):
        train(args, model, training_generator, optimizer, epoch, writer)
        val_metrics, confusion_matrix = validation(args, model, val_generator,
                                                   epoch, writer)

        # save_model returns the updated best validation loss so far.
        best_pred_loss = util.save_model(model, optimizer, args, val_metrics,
                                         epoch, best_pred_loss,
                                         confusion_matrix)

        scheduler.step(val_metrics.avg_loss())
Пример #4
0
def main():
    """Entry point: seed all RNGs, optionally resume from a checkpoint, train."""
    args = get_arguments()
    # Fix every RNG source for reproducibility.
    SEED = args.seed
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)
    if (args.cuda):
        torch.cuda.manual_seed(SEED)
    if args.new_training:
        # Resume: restore model/optimizer state and the last finished epoch.
        model, optimizer, training_generator, val_generator, class_weight, Last_epoch = initialize_from_saved_model(args)
    else:
        model, optimizer, training_generator, val_generator, class_weight = initialize(args)
        Last_epoch = 0

    #print(model)

    best_pred_loss = 0  # tracks balanced accuracy (higher is better), not loss
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=3, min_lr=1e-5, verbose=True)
    print('Checkpoint folder ', args.save)
    # writer = SummaryWriter(log_dir='../runs/' + args.model, comment=args.model)
    for epoch in range(1, args.nEpochs + 1):
        # Offset by Last_epoch so logs/checkpoints continue the previous run.
        train(args, model, training_generator, optimizer, Last_epoch+epoch, class_weight)
        val_metrics, confusion_matrix = validation(args, model, val_generator, Last_epoch+epoch, class_weight)
        BACC = BalancedAccuray(confusion_matrix.numpy())
        val_metrics.replace({'bacc': BACC})
        best_pred_loss = util.save_model(model, optimizer, args, val_metrics, Last_epoch+epoch, best_pred_loss, confusion_matrix)

        print(confusion_matrix)
        scheduler.step(val_metrics.avg_loss())
Пример #5
0
    def checkpointer(self, epoch, metric):
        """Persist the latest checkpoint, and the best one when `metric` improves.

        A lower `metric` (validation loss) is better; `self.mnt_best` holds the
        best value seen so far.
        """
        if metric < self.mnt_best:
            self.mnt_best = metric
            self.logger.info(f"Best val loss {self.mnt_best} so far ")
            save_model(self.checkpoint_dir, self.model, self.optimizer,
                       self.valid_metrics.avg('loss'), epoch, f'_model_best')
        # Always keep a rolling "last" checkpoint regardless of improvement.
        save_model(self.checkpoint_dir, self.model, self.optimizer,
                   self.valid_metrics.avg('loss'), epoch, f'_model_last')
Пример #6
0
def train_model(model, data_loader, phases, normal, loss_func, optimizer,
                num_epochs, tensorboard_folder: str, model_folder_name):
    """Run a train/validate/test loop with best-on-validation checkpointing.

    Args:
        model: network to train.
        data_loader: dict mapping each phase name to its DataLoader.
        phases: iterable of phase names (e.g. ['train', 'validate', 'test']).
        normal: normalizer exposing ``rmse_transform`` to undo loss scaling.
        loss_func: loss callable.
        optimizer: optimizer for ``model``.
        num_epochs: number of epochs to run.
        tensorboard_folder: SummaryWriter log directory.
        model_folder_name: path the best checkpoint is saved to.

    Returns:
        The model with the best-validation weights loaded.
    """
    # time.clock() was removed in Python 3.8; perf_counter is its replacement.
    since = time.perf_counter()

    writer = SummaryWriter(tensorboard_folder)
    save_dict, best_rmse, best_test_rmse = {
        'model_state_dict': copy.deepcopy(model.state_dict()),
        'epoch': 0
    }, 999999, 999999
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=.5,
                                                     patience=2,
                                                     threshold=1e-3,
                                                     min_lr=1e-6)

    try:
        for epoch in range(num_epochs):

            running_loss = {phase: 0.0 for phase in phases}
            for phase in phases:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                steps, ground_truth, prediction = 0, list(), list()
                tqdm_loader = tqdm(data_loader[phase], phase)

                for x, y in tqdm_loader:
                    # x.to(get_config("device"))
                    # y.to(get_config("device"))
                    with torch.set_grad_enabled(phase == 'train'):
                        y_pred = model(x)
                        loss = loss_func(y_pred, y)

                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    ground_truth.append(y.cpu().detach().numpy())
                    prediction.append(y_pred.cpu().detach().numpy())

                    # .item() detaches the scalar: accumulating the raw loss
                    # tensor would keep every batch's autograd graph alive for
                    # the whole epoch (the CPU/GPU memory growth this loop
                    # fights with empty_cache() below).
                    running_loss[phase] += loss.item() * y.size(0)
                    steps += y.size(0)

                    tqdm_loader.set_description(
                        f"{phase:8} epoch: {epoch:3}  loss: {running_loss[phase] / steps:3.6},"
                        f"true loss: {normal.rmse_transform(running_loss[phase] / steps):3.6}"
                    )
                    torch.cuda.empty_cache()

                # Track the best validation loss and the test loss at that epoch.
                if phase == 'validate' and running_loss[
                        phase] / steps <= best_rmse:
                    best_rmse = running_loss[phase] / steps
                    save_dict.update(model_state_dict=copy.deepcopy(
                        model.state_dict()),
                                     epoch=epoch)
                if phase == 'test' and save_dict['epoch'] == epoch:
                    best_test_rmse = running_loss[phase] / steps

            scheduler.step(running_loss['train'])

            writer.add_scalars(
                f'Loss', {
                    f'{phase} loss':
                    running_loss[phase] / len(data_loader[phase].dataset)
                    for phase in phases
                }, epoch)
    finally:
        # Save/report even when interrupted mid-training.
        time_elapsed = time.perf_counter() - since
        print(
            f"cost time: {time_elapsed:.2} seconds   best val loss: {normal.rmse_transform(best_rmse)}   "
            f"best test loss:{normal.rmse_transform(best_test_rmse)}  best epoch: {save_dict['epoch']}"
        )
        save_model(f"{model_folder_name}", **save_dict)
        # model.load_state_dict(torch.load(f"{model_folder_name}")['model_state_dict'])
        model.load_state_dict(save_dict['model_state_dict'])
    return model
Пример #7
0
def train_model(model: nn.Module, train_data_loader: DataLoader,
                valid_data_loader: DataLoader, loss_func, epochs, optimizer,
                model_folder, tensorboard_folder):
    """Train a graph ranking model, checkpointing on best mean validation NDCG.

    Args:
        model: nn.Module
        train_data_loader: DataLoader
        valid_data_loader: DataLoader
        loss_func: nn.Module
        epochs: int
        optimizer: Optimizer
        model_folder: str
        tensorboard_folder: str
    """
    warnings.filterwarnings('ignore')

    print(model)
    print(optimizer)

    writer = SummaryWriter(tensorboard_folder)
    writer.add_text('Welcome', 'Welcome to tensorboard!')

    model = convert_to_gpu(model)
    model.train()
    loss_func = convert_to_gpu(loss_func)

    start_time = datetime.datetime.now()

    validate_max_ndcg = 0
    name_list = ["train", "validate"]

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    for epoch in range(epochs):
        # Per-phase loss totals and metric dicts, reset each epoch.
        loss_dict, metric_dict = {name: 0.0
                                  for name in name_list
                                  }, {name: dict()
                                      for name in name_list}
        data_loader_dic = {
            "train": train_data_loader,
            "validate": valid_data_loader
        }

        for name in name_list:
            # training
            if name == "train":
                model.train()
            # validate
            else:
                model.eval()

            y_true = []
            y_pred = []
            total_loss = 0.0
            tqdm_loader = tqdm(data_loader_dic[name])
            for step, (g, nodes_feature, edges_weight, lengths, nodes,
                       truth_data, users_frequency) in enumerate(tqdm_loader):
                g, nodes_feature, edges_weight, lengths, nodes, truth_data, users_frequency = \
                    convert_all_data_to_gpu(g, nodes_feature, edges_weight, lengths, nodes, truth_data, users_frequency)

                # Gradients are tracked only during the training phase.
                with torch.set_grad_enabled(name == 'train'):
                    # (B, N)
                    output = model(g, nodes_feature, edges_weight, lengths,
                                   nodes, users_frequency)
                    loss = loss_func(output, truth_data.float())
                    total_loss += loss.cpu().data.numpy()
                    if name == "train":
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                    y_pred.append(output.detach().cpu())
                    y_true.append(truth_data.detach().cpu())
                    tqdm_loader.set_description(
                        f'{name} epoch: {epoch}, {name} loss: {total_loss / (step + 1)}'
                    )

            # NOTE(review): `step` leaks from the loop above; this raises
            # NameError if the data loader is empty — assumed non-empty.
            loss_dict[name] = total_loss / (step + 1)
            y_true = torch.cat(y_true, dim=0)
            y_pred = torch.cat(y_pred, dim=0)

            print(f'{name} metric ...')
            scores = get_metric(y_true=y_true, y_pred=y_pred)
            # Sort metrics by name for stable, readable output.
            scores = sorted(scores.items(),
                            key=lambda item: item[0],
                            reverse=False)
            scores = {item[0]: item[1] for item in scores}
            print(json.dumps(scores, indent=4))
            metric_dict[name] = scores

            # save best model
            if name == "validate":
                # Average all ndcg_* metrics; checkpoint on improvement.
                validate_ndcg_list = []
                for key in metric_dict["validate"]:
                    if key.startswith("ndcg_"):
                        validate_ndcg_list.append(metric_dict["validate"][key])
                validate_ndcg = np.mean(validate_ndcg_list)
                if validate_ndcg > validate_max_ndcg:
                    validate_max_ndcg = validate_ndcg
                    model_path = f"{model_folder}/model_epoch_{epoch}.pkl"
                    save_model(model, model_path)
                    print(f"model save as {model_path}")

        scheduler.step(loss_dict['train'])

        writer.add_scalars(
            'Loss', {f'{name} loss': loss_dict[name]
                     for name in name_list},
            global_step=epoch)

        for metric in metric_dict['train'].keys():
            for name in name_list:
                writer.add_scalars(f'{name} {metric}',
                                   {f'{metric}': metric_dict[name][metric]},
                                   global_step=epoch)

    end_time = datetime.datetime.now()
    print("cost %d seconds" % (end_time - start_time).seconds)
Пример #8
0
def train_model(model: nn.Module, data_loaders: Dict[str, DataLoader],
                loss_func: callable, optimizer: optim.Optimizer,
                model_folder: str, tensorboard_folder: str,
                args, **kwargs):
    """Train with train/val/test phases, checkpointing on best validation PCC.

    Args:
        model: network to train.
        data_loaders: dict mapping 'train'/'val'/'test' to DataLoaders.
        loss_func: loss callable; used directly or passed into the model
            when ``args.lossinside`` is set.
        optimizer: optimizer for ``model``.
        model_folder: directory for checkpoints and the text log.
        args: namespace providing epochs, device, lossinside, etc.
        **kwargs: forwarded to ``calculate_metrics``.
    """
    num_epochs = args.epochs
    phases = ['train', 'val', 'test']

    writer = SummaryWriter(tensorboard_folder)

    # time.clock() was removed in Python 3.8; perf_counter is its replacement.
    since = time.perf_counter()

    # save_dict, best_rmse = {'model_state_dict': copy.deepcopy(model.state_dict()), 'epoch': 0}, 100000
    save_dict, best_pcc = {'model_state_dict': copy.deepcopy(model.state_dict()), 'epoch': 0}, 0

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.2, patience=5, threshold=1e-3, min_lr=1e-6)

    try:
        for epoch in range(num_epochs):
            running_loss = {phase: 0.0 for phase in phases}
            for phase in phases:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                steps, predictions, targets = 0, list(), list()
                tqdm_loader = tqdm(enumerate(data_loaders[phase]))
                for step, (features, truth_data) in tqdm_loader:
                    features = to_var(features, args.device)
                    truth_data = to_var(truth_data, args.device)
                    with torch.set_grad_enabled(phase == 'train'):
                        if args.lossinside:
                            loss, outputs = model(features, truth_data, args, loss_func=loss_func)
                        else:
                            outputs = model(features, args)
                            loss = loss_func(truth=truth_data, predict=outputs)
                        # loss = loss_func(outputs, truth_data)

                        if phase == 'train':
                            # Skip the update on NaN loss and dump the batch
                            # for debugging instead of corrupting the weights.
                            if torch.isnan(loss):
                                print("=============LOSS NAN============")
                                print(features)
                                print(truth_data)
                                print(outputs)
                            else:
                                optimizer.zero_grad()
                                loss.backward()
                                optimizer.step()

                    targets.append(truth_data.cpu().numpy())
                    with torch.no_grad():
                        predictions.append(outputs.cpu().detach().numpy())

                    # .item() detaches the scalar: accumulating the raw loss
                    # tensor keeps every batch's autograd graph alive, which
                    # is the memory growth noted below.
                    running_loss[phase] += loss.item() * truth_data.size(0)
                    steps += truth_data.size(0)

                    tqdm_loader.set_description(
                        f'{phase} epoch: {epoch}, {phase} loss: {running_loss[phase] / steps}')

                    # For the issue that the CPU memory increases while training. DO NOT know why, but it works.
                    torch.cuda.empty_cache()
                # Per-phase metrics over the whole epoch.
                predictions = np.concatenate(predictions)
                targets = np.concatenate(targets)
                # print(2)
                # print(predictions[:3, :3])
                # print(targets[:3, :3])
                scores = calculate_metrics(predictions.reshape(predictions.shape[0], -1),
                                           targets.reshape(targets.shape[0], -1), args, plot=epoch % 5 == 0, **kwargs)
                # print(3)
                writer.add_scalars(f'score/{phase}', scores, global_step=epoch)
                with open(model_folder+"/output.txt", "a") as f:
                    f.write(f'{phase} epoch: {epoch}, {phase} loss: {running_loss[phase] / steps}\n')
                    f.write(str(scores))
                    f.write('\n')
                    f.write(str(time.time()))
                    f.write("\n\n")
                print(scores)
                # if phase == 'val' and scores['RMSE'] < best_rmse:
                if phase == 'val' and scores['pearr'] > best_pcc:
                    best_pcc = scores['pearr']
                    # best_rmse = scores['RMSE']
                    save_dict.update(model_state_dict=copy.deepcopy(model.state_dict()),
                                     epoch=epoch,
                                     optimizer_state_dict=copy.deepcopy(optimizer.state_dict()))

            scheduler.step(running_loss['train'])

            writer.add_scalars('Loss', {
                f'{phase} loss': running_loss[phase] / len(data_loaders[phase].dataset) for phase in phases},
                               global_step=epoch)
    finally:
        # Save both the best and the final model even when interrupted.
        time_elapsed = time.perf_counter() - since
        print(f"cost {time_elapsed} seconds")

        save_model(f"{model_folder}/best_model.pkl", **save_dict)
        save_model(f"{model_folder}/final_model.pkl",
                   **{'model_state_dict': copy.deepcopy(model.state_dict()),
                      'epoch': num_epochs,
                      'optimizer_state_dict': copy.deepcopy(optimizer.state_dict())})
def train_my_model(model: nn.Module, data_loader, loss_func: callable, optimizer, num_epochs, model_folder,
                tensorboard_folder: str, **kwargs):
    """Train a taxi/bike demand model with a weighted multi-task loss;
    checkpoint on best validation RMSE and return the best model.

    Assumes ``data_loader`` maps 'train'/'val'/'test' to loaders yielding
    (features, truth, covariate) triples — TODO confirm against caller.
    """
    phases = ['train', 'val', 'test']
    writer = SummaryWriter(tensorboard_folder)
    model = convert_to_gpu(model)
    #model = nn.DataParallel(convert_to_gpu(model), [0, 1])
    loss_func = convert_to_gpu(loss_func)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.1, patience=8, threshold=1e-4, min_lr=1e-6)
    save_dict = {'model_state_dict': copy.deepcopy(model.state_dict()), 'epoch': 0}
    loss_global = 100000  # best validation RMSE seen so far
    for epoch in range(num_epochs):
        running_loss = {phase: 0.0 for phase in phases}
        for phase in phases:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            steps, predictions, targets = 0, list(), list()
            tqdm_loaders = tqdm(enumerate(data_loader[phase]))
            for step, (features, truth, covariate) in tqdm_loaders:
                features = convert_to_gpu(features)
                truth = convert_to_gpu(truth)
                covariate = convert_to_gpu(covariate)
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(features, covariate)
                    # When the loss is computed on raw values, de-normalize
                    # before computing it; otherwise de-normalize afterwards.
                    if not get_Parameter('loss_normalized'):
                        outputs, truth = normalized_transform(outputs, truth, **kwargs)
                    # First taxi_size nodes are taxi; channel 0 = pickup, 1 = dropoff.
                    taxi_pickup_loss = loss_func(truth[:, :, :get_Parameter('taxi_size'), 0], outputs[:, :, :get_Parameter('taxi_size'), 0])
                    taxi_dropoff_loss = loss_func(truth[:, :, :get_Parameter('taxi_size'), 1], outputs[:, :, :get_Parameter('taxi_size'), 1])
                    #taxi_loss = loss_func(truth[:, :, :get_Parameter('taxi_size')], outputs[:, :, :get_Parameter('taxi_size')])
                    taxi_loss = taxi_pickup_loss + taxi_dropoff_loss*1.5
                    bike_loss = loss_func(truth[:, :, get_Parameter('taxi_size'):], outputs[:, :, get_Parameter('taxi_size'):])
                    # if epoch<=100:
                    #     loss = (2*taxi_loss + bike_loss)*100
                    # else:
                    #     loss = taxi_loss
                    #loss = taxi_loss + 30*bike_loss
                    # Hand-tuned task weighting; *100 scales the gradient magnitude.
                    loss = (1.5*taxi_loss + bike_loss)*100
                    #loss = loss_func(truth, outputs)
                    #loss = bike_loss
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                    if get_Parameter('loss_normalized'):
                        outputs, truth = normalized_transform(outputs, truth, **kwargs)
                targets.append(truth.cpu().numpy())
                with torch.no_grad():
                    predictions.append(outputs.cpu().numpy())
                running_loss[phase] += loss.item()
                steps += truth.size(0)

                tqdm_loaders.set_description(f'{phase} epoch:{epoch}, {phase} loss: {running_loss[phase]/steps}')

            predictions = np.concatenate(predictions)
            targets = np.concatenate(targets)

            scores = calculate_metrics(predictions.reshape(predictions.shape[0], -1),
                                       targets.reshape(targets.shape[0], -1), mode='train', **kwargs)
            print(scores)
            writer.add_scalars(f'score/{phase}', scores, global_step=epoch)
            # Checkpoint whenever validation RMSE improves.
            if phase == 'val' and scores['RMSE'] < loss_global:
                loss_global = scores['RMSE']
                save_dict.update(model_state_dict=copy.deepcopy(model.state_dict()), epoch=epoch,
                                 optimizer_state_dict=copy.deepcopy(optimizer.state_dict()))

        scheduler.step(running_loss['train'])
        writer.add_scalars('Loss', {
            f'{phase} loss': running_loss[phase] for phase in phases
        }, global_step=epoch)

    save_model(f'{model_folder}/best_model.pkl', **save_dict)
    model.load_state_dict(save_dict['model_state_dict'])
    return model
Пример #10
0
def train_model(data_loaders,
                model,
                criterion,
                optimizer,
                lr_scheduler,
                num_epochs=25,
                device=None):
    """Train a classifier with per-epoch hard negative mining.

    Each epoch runs a train/val pass, then scores the remaining negative
    pool to collect hard negatives, which are appended to the training set
    before the next epoch.

    NOTE(review): relies on module-level globals not defined here
    (``data_sizes``, ``batch_total``, ``batch_positive``, ``batch_negative``,
    ``CustomHardNegativeMiningDataset``, ``CustomBatchSampler``,
    ``get_hard_negatives``, ``add_hard_negatives``, ``save_model``) — verify
    against the surrounding module.

    Returns:
        The model loaded with the best-validation-accuracy weights.
    """
    since = time.time()

    best_model_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):

        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Print positive/negative sample counts for this phase.
            data_set = data_loaders[phase].dataset
            print('{} - positive_num: {} - negative_num: {} - data size: {}'.
                  format(phase, data_set.get_positive_num(),
                         data_set.get_negative_num(), data_sizes[phase]))

            # Iterate over data.
            for inputs, labels, cache_dicts in data_loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    # print(outputs.shape)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                lr_scheduler.step()

            epoch_loss = running_loss / data_sizes[phase]
            epoch_acc = running_corrects.double() / data_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_weights = copy.deepcopy(model.state_dict())

        # After each epoch, score the remaining negative pool for hard
        # negative mining.
        train_dataset = data_loaders['train'].dataset
        remain_negative_list = data_loaders['remain']
        jpeg_images = train_dataset.get_jpeg_images()
        transform = train_dataset.get_transform()

        with torch.set_grad_enabled(False):
            remain_dataset = CustomHardNegativeMiningDataset(
                remain_negative_list, jpeg_images, transform=transform)
            remain_data_loader = DataLoader(remain_dataset,
                                            batch_size=batch_total,
                                            num_workers=8,
                                            drop_last=True)

            # Current negative set of the training dataset.
            negative_list = train_dataset.get_negatives()
            # Negatives accumulated from earlier mining rounds.
            add_negative_list = data_loaders.get('add_negative', [])

            running_corrects = 0
            # Iterate over data.
            for inputs, labels, cache_dicts in remain_data_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = model(inputs)
                # print(outputs.shape)
                _, preds = torch.max(outputs, 1)

                running_corrects += torch.sum(preds == labels.data)

                # Misclassified negatives become hard negatives for the next epoch.
                hard_negative_list, easy_neagtive_list = get_hard_negatives(
                    preds.cpu().numpy(), cache_dicts)
                add_hard_negatives(hard_negative_list, negative_list,
                                   add_negative_list)

            remain_acc = running_corrects.double() / len(remain_negative_list)
            print('remiam negative size: {}, acc: {:.4f}'.format(
                len(remain_negative_list), remain_acc))

            # Rebuild the train loader with the augmented negative list so
            # the next epoch trains on the mined hard negatives.
            train_dataset.set_negative_list(negative_list)
            tmp_sampler = CustomBatchSampler(train_dataset.get_positive_num(),
                                             train_dataset.get_negative_num(),
                                             batch_positive, batch_negative)
            data_loaders['train'] = DataLoader(train_dataset,
                                               batch_size=batch_total,
                                               sampler=tmp_sampler,
                                               num_workers=8,
                                               drop_last=True)
            data_loaders['add_negative'] = add_negative_list
            # Update the recorded training-set size.
            data_sizes['train'] = len(tmp_sampler)

        # Save a checkpoint after every epoch.
        save_model(model, 'models/linear_svm_alexnet_car_%d.pth' % epoch)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_weights)
    return model
Пример #11
0
def train_decompose(model: nn.Module,
                    dataloaders,
                    optimizer,
                    scheduler,
                    folder: str,
                    trainer,
                    tensorboard_floder,
                    epochs: int,
                    device,
                    max_grad_norm: float = None,
                    early_stop_steps: float = None):
    """Run a train/validate/test loop with checkpoint resume and early stop.

    Args:
        model: network to train.
        dataloaders: mapping with 'train', 'validate' and 'test' DataLoaders.
        optimizer: optimizer whose state is checkpointed alongside the model.
        scheduler: LR scheduler; stepped once per epoch on the train loss.
        folder: directory holding ``best_model.pkl`` (resume source and save
            target).
        trainer: object whose ``train(inputs, targets, phase)`` returns
            ``(outputs, loss)``.
        tensorboard_floder: TensorBoard log directory (misspelled name kept
            for backward compatibility with existing callers).
        epochs: number of epochs to run from the resume point.
        device: device the model and batches are moved to.
        max_grad_norm: if not None, clip gradients to this norm.
        early_stop_steps: if not None, stop once validation has not improved
            for this many epochs.

    Returns:
        The model, loaded with the best recorded weights when training ended
        via early stop or interrupt.
    """
    save_path = os.path.join(folder, 'best_model.pkl')

    # Resume from an earlier run when a checkpoint exists.
    if os.path.exists(save_path):
        save_dict = torch.load(save_path)

        model.load_state_dict(save_dict['model_state_dict'])
        optimizer.load_state_dict(save_dict['optimizer_state_dict'])

        best_val_loss = save_dict['best_val_loss']
        begin_epoch = save_dict['epoch'] + 1
    else:
        save_dict = dict()
        best_val_loss = float('inf')
        begin_epoch = 0

    phases = ['train', 'validate', 'test']

    writer = SummaryWriter(tensorboard_floder)

    since = time.perf_counter()

    model = model.to(device)

    print(model)
    print(f'Trainable parameters: {get_number_of_parameters(model)}.')

    try:
        for epoch in range(begin_epoch, begin_epoch + epochs):

            running_loss, running_metrics = defaultdict(float), dict()
            for phase in phases:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                steps, predictions, running_targets = 0, list(), list()
                tqdm_loader = tqdm(enumerate(dataloaders[phase]))
                for step, (inputs, targets) in tqdm_loader:
                    running_targets.append(targets.numpy())

                    with torch.no_grad():
                        inputs = inputs.to(device)
                        targets = targets.to(device)

                    with torch.set_grad_enabled(phase == 'train'):
                        outputs, loss = trainer.train(inputs, targets, phase)

                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            if max_grad_norm is not None:
                                nn.utils.clip_grad_norm_(
                                    model.parameters(), max_grad_norm)
                            optimizer.step()

                    with torch.no_grad():
                        predictions.append(outputs.cpu().numpy())

                    # Accumulate a Python float, not the loss tensor:
                    # summing tensors keeps every step's autograd graph
                    # alive and is the likely cause of the memory growth
                    # this loop previously fought with empty_cache().
                    running_loss[phase] += loss.item() * len(targets)
                    steps += len(targets)

                    tqdm_loader.set_description(
                        f'{phase:5} epoch: {epoch:3}, {phase:5} loss: {running_loss[phase] / steps:3.6}'
                    )

                    torch.cuda.empty_cache()

                if phase == 'validate':
                    if running_loss['validate'] < best_val_loss:
                        best_val_loss = running_loss['validate']
                        save_dict.update(model_state_dict=copy.deepcopy(
                            model.state_dict()),
                                         epoch=epoch,
                                         best_val_loss=best_val_loss,
                                         optimizer_state_dict=copy.deepcopy(
                                             optimizer.state_dict()))
                        print(f'Better model at epoch {epoch} recorded.')
                    # early_stop_steps defaults to None; ``int > None``
                    # raises TypeError on Python 3, so guard explicitly.
                    elif (early_stop_steps is not None
                          and epoch - save_dict['epoch'] > early_stop_steps):
                        raise ValueError('Early stopped.')

            scheduler.step(running_loss['train'])
    except (ValueError, KeyboardInterrupt):
        time_elapsed = time.perf_counter() - since
        print(f"cost {time_elapsed} seconds")

        model.load_state_dict(save_dict['model_state_dict'])
        print(save_path)
        save_model(save_path, **save_dict)
        print(
            f'model of epoch {save_dict["epoch"]} successfully saved at `{save_path}`'
        )

    return model
Пример #12
0
    def train(self,
              mention_dataset,
              candidate_dataset,
              inbatch=True,
              lr=1e-5,
              batch_size=32,
              random_bsz=100000,
              max_ctxt_len=32,
              max_title_len=50,
              max_desc_len=100,
              traindata_size=1000000,
              model_save_interval=10000,
              grad_acc_step=1,
              max_grad_norm=1.0,
              epochs=1,
              warmup_propotion=0.1,
              fp16=False,
              fp16_opt_level=None,
              parallel=False,
              hard_negative=False,
             ):
        """Train the bi-encoder with in-batch negatives.

        Mentions and their gold candidate pages are encoded separately; the
        score matrix ``mention_reps @ candidate_reps.T`` is trained with
        cross-entropy so that the i-th mention matches the i-th candidate.
        When ``hard_negative`` is set, neighbors ranked above the gold label
        in each mention's ``nearest_neighbors`` list are appended as extra
        negative candidates. Supports fp16 (apex), data parallelism and
        gradient accumulation; checkpoints every ``model_save_interval``
        steps when ``self.model_path`` is set.
        """
        if inbatch:

            optimizer = optim.Adam(self.model.parameters(), lr=lr)
            scheduler = get_scheduler(
                batch_size, grad_acc_step, epochs, warmup_propotion, optimizer, traindata_size)

            if fp16:
                assert fp16_opt_level is not None
                self.model, optimizer = to_fp16(self.model, optimizer, fp16_opt_level)

            if parallel:
                self.model = to_parallel(self.model)

            for e in range(epochs):
                dataloader = DataLoader(mention_dataset, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn_json, num_workers=2)
                bar = tqdm(total=traindata_size)
                for step, (input_ids, labels, lines) in enumerate(dataloader):
                    if self.logger:
                        self.logger.debug("%s step", step)
                        self.logger.debug("%s data in batch", len(input_ids))
                        self.logger.debug("%s unique labels in %s labels", len(set(labels)), len(labels))

                    # Pad to a (batch, seq) LongTensor; 0 doubles as both the
                    # pad value and the mask criterion.
                    inputs = pad_sequence([torch.LongTensor(token)
                                          for token in input_ids], padding_value=0).t().to(self.device)
                    input_mask = inputs > 0

                    mention_reps = self.model(inputs, input_mask, is_mention=True)

                    # Gold pages first; optionally append hard negatives that
                    # rank above the gold label for each mention.
                    pages = list(labels[:])
                    if hard_negative:
                        for label, line in zip(labels, lines):
                            for i in line["nearest_neighbors"]:
                                if str(i) == label:
                                    break
                                pages.append(str(i))

                    candidate_input_ids = candidate_dataset.get_pages(pages, max_title_len=max_title_len, max_desc_len=max_desc_len)
                    candidate_inputs = pad_sequence([torch.LongTensor(token)
                                                    for token in candidate_input_ids], padding_value=0).t().to(self.device)
                    candidate_mask = candidate_inputs > 0
                    candidate_reps = self.model(candidate_inputs, candidate_mask, is_mention=False)

                    scores = mention_reps.mm(candidate_reps.t())
                    accuracy = self.calculate_inbatch_accuracy(scores)

                    # The i-th mention's positive candidate sits at column i,
                    # so the targets are simply 0..batch-1. torch.arange
                    # already yields int64; no LongTensor re-wrap needed.
                    target = torch.arange(scores.size(0), dtype=torch.long).to(self.device)
                    loss = F.cross_entropy(scores, target, reduction="mean")

                    if self.logger:
                        self.logger.debug("Accuracy: %s", accuracy)
                        self.logger.debug("Train loss: %s", loss.item())

                    if fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    # Step/clip only every grad_acc_step mini-batches.
                    if (step + 1) % grad_acc_step == 0:
                        if fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), max_grad_norm
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(), max_grad_norm
                            )
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()

                        if self.logger:
                            self.logger.debug("Back propagation in step %s", step+1)
                            self.logger.debug("LR: %s", scheduler.get_lr())

                    if self.use_mlflow:
                        mlflow.log_metric("train loss", loss.item(), step=step)
                        mlflow.log_metric("accuracy", accuracy, step=step)

                    if self.model_path is not None and step % model_save_interval == 0:
                        save_model(self.model, self.model_path)

                    bar.update(len(input_ids))
                    bar.set_description(f"Loss: {loss.item()}, Accuracy: {accuracy}")
Пример #13
0
def train_model(model: nn.Module,
                data_loaders: Dict[str, DataLoader],
                loss_func: callable,
                optimizer,
                model_folder: str,
                tensorboard_folder: str,
                pid: int):
    """Train a graph model, tracking the best validation F1 score.

    Runs ``get_attribute('epochs')`` epochs over the 'train', 'validate' and
    'test' phases, zeroing out feature groups that are disabled via the
    ``use_*_features`` attributes. Saves a checkpoint whenever the validation
    F1-SCORE improves, steps a ReduceLROnPlateau scheduler on the epoch train
    loss, logs metrics to TensorBoard, and always writes a final
    ``best_model.pkl`` on exit (including on exceptions, via ``finally``).

    Args:
        model: network to train (moved to GPU via ``convert_to_gpu``).
        data_loaders: mapping from phase name to its DataLoader.
        loss_func: callable invoked as ``loss_func(truth=..., predict=...)``.
        optimizer: optimizer for the model parameters.
        model_folder: directory for per-epoch and final checkpoints.
        tensorboard_folder: TensorBoard log directory.
        pid: worker id, used only to tag progress-bar output.

    Returns:
        The test-phase metrics recorded at the best validation epoch, or
        ``None`` if validation never improved.
    """
    phases = ['train', 'validate', 'test']

    writer = SummaryWriter(tensorboard_folder)
    num_epochs = get_attribute('epochs')

    since = time.perf_counter()

    model = convert_to_gpu(model)
    loss_func = convert_to_gpu(loss_func)

    save_dict, best_f1_score = {'model_state_dict': copy.deepcopy(model.state_dict()), 'epoch': 0}, 0

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.5, patience=2, threshold=1e-3, min_lr=1e-6)
    test_metric = None
    try:
        for epoch in range(num_epochs):

            running_loss, running_metrics = {phase: 0.0 for phase in phases}, {phase: dict() for phase in phases}
            save_validate_this_epoch = False
            for phase in phases:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                steps, predictions, targets = 0, list(), list()
                tqdm_loader = tqdm(enumerate(data_loaders[phase]))
                for step, (g, spatial_features, temporal_features, external_features, truth_data) in tqdm_loader:

                    # Skip ragged trailing batches.
                    if list(external_features.size())[0] != get_attribute("batch_size"):
                        continue

                    # Ablation switches: zero out disabled feature groups.
                    if not get_attribute("use_spatial_features"):
                        torch.zero_(spatial_features)
                    if not get_attribute("use_temporal_features"):
                        torch.zero_(temporal_features)
                    if not get_attribute("use_external_features"):
                        torch.zero_(external_features)

                    features, truth_data = convert_train_truth_to_gpu(
                        [spatial_features, temporal_features, external_features], truth_data)

                    with torch.set_grad_enabled(phase == 'train'):
                        _outputs = model(g, *features)
                        outputs = torch.squeeze(_outputs)  # squeeze [batch-size, 1] to [batch-size]
                        loss = loss_func(truth=truth_data, predict=outputs)
                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    targets.append(truth_data.cpu().numpy())
                    with torch.no_grad():
                        predictions.append(outputs.cpu().detach().numpy())

                    # Accumulate a float, not the loss tensor: summing the
                    # tensor retains each step's autograd graph and leaks
                    # memory (the reason for the empty_cache() workaround).
                    running_loss[phase] += loss.item() * truth_data.size(0)
                    steps += truth_data.size(0)

                    tqdm_loader.set_description(
                        f'{pid:2} pid: {phase:8} epoch: {epoch:3}, {phase:8} loss: {running_loss[phase] / steps:3.6}')

                    torch.cuda.empty_cache()

                print(f'{phase} metric ...')
                _cp = np.concatenate(predictions)
                _ct = np.concatenate(targets)
                scores = evaluate(_cp, _ct)
                running_metrics[phase] = scores
                print(scores)

                if phase == 'validate' and scores['F1-SCORE'] > best_f1_score:
                    best_f1_score = scores['F1-SCORE']
                    save_validate_this_epoch = True
                    save_dict.update(model_state_dict=copy.deepcopy(model.state_dict()),
                                     epoch=epoch,
                                     optimizer_state_dict=copy.deepcopy(optimizer.state_dict()))
                    print(f"save model as {model_folder}/model_{epoch}.pkl")
                    save_model(f"{model_folder}/model_{epoch}.pkl", **save_dict)

            scheduler.step(running_loss['train'])

            # Remember the test metrics from the epoch with the best
            # validation score.
            if save_validate_this_epoch:
                test_metric = running_metrics["test"].copy()

            for metric in running_metrics['train'].keys():
                writer.add_scalars(metric, {
                    f'{phase} {metric}': running_metrics[phase][metric] for phase in phases},
                                   global_step=epoch)
            writer.add_scalars('Loss', {
                f'{phase} loss': running_loss[phase] / len(data_loaders[phase].dataset) for phase in phases},
                               global_step=epoch)
    finally:

        time_elapsed = time.perf_counter() - since
        print(f"cost {time_elapsed} seconds")

        save_model(f"{model_folder}/best_model.pkl", **save_dict)

    return test_metric
Пример #14
0
def train(args):
    """Train the Tacotron-style mel-spectrogram model.

    Each epoch runs a teacher-forced training pass, a teacher-forced
    validation pass, and an "inference" pass where the decoder input mel is
    zeroed. Losses, scripts, attention maps and mel images are logged to the
    module-level TensorBoard ``writer``; the model is checkpointed every
    ``LOGGING_STEPS`` steps.

    Args:
        args: parsed CLI args; ``args.load`` optionally names a checkpoint to
            resume from and ``args.save`` is the checkpoint directory.
    """
    step = 0
    loss_list = list()

    plt.rc('font', family='Malgun Gothic')
    model = Model(configs).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    if args.load != '':
        model, optimizer, step = util.load_model(args.load, model, optimizer)

    util.mkdir(args.save)

    for epoch in range(100):

        train_data_loader, valid_data_loader = get_data_loaders(configs)

        for i, data in tqdm(enumerate(train_data_loader),
                            total=int(
                                len(train_data_loader.dataset) /
                                train_data_loader.batch_size)):
            step += 1
            path_list, mel_batch, encoded_batch, text_list, mel_length_list, encoded_length_list = data

            mel_out, stop_tokens, enc_attention, dec_attention = model(
                encoded_batch, mel_batch)
            loss = nn.L1Loss()(mel_out.cuda(), mel_batch.cuda())
            loss_list.append(loss.item())
            if step % LOGGING_STEPS == 0:
                writer.add_scalar('loss', np.mean(loss_list), step)
                writer.add_text('script', text_list[0], step)
                # NOTE: inner loop indices renamed so they no longer shadow
                # the outer batch index ``i``.
                for layer_idx, prob in enumerate(enc_attention):
                    for j in range(4):
                        x = torchvision.utils.make_grid(prob[j * 4] * 255)
                        writer.add_image('ENC_Attention_%d_0' % step, x, step)
                print(dec_attention[0].shape)
                for layer_idx, prob in enumerate(dec_attention):
                    for j in range(4):
                        x = torchvision.utils.make_grid(prob[j * 4] * 255)
                        writer.add_image('DEC_Attention_%d_0' % step, x, step)

                image = matrix_to_plt_image(
                    mel_batch[0].cpu().detach().numpy().T, text_list[0])
                writer.add_image('mel_in', image, step,
                                 dataformats="HWC")  # (1, 80, T)

                image = matrix_to_plt_image(
                    mel_out[0].cpu().detach().numpy().T, text_list[0])
                writer.add_image('mel_out', image, step,
                                 dataformats="HWC")  # (1, 80, T)

                util.save_model(model, optimizer, args.save, step)
                loss_list = list()

            optimizer.zero_grad()
            loss.backward()
            # Gradient update must follow backward().
            optimizer.step()

        # ---------- teacher-forced validation ----------
        loss_list_test = list()
        for i, data in tqdm(enumerate(valid_data_loader),
                            total=int(
                                len(valid_data_loader.dataset) /
                                valid_data_loader.batch_size)):
            path_list, mel_batch, encoded_batch, text_list, mel_length_list, encoded_length_list = data
            # The model returns four values (see the training loop above);
            # unpacking two raised ValueError here.
            mel_out, stop_tokens, _, _ = model(encoded_batch, mel_batch)
            loss = nn.L1Loss()(mel_out.cuda(), mel_batch.cuda())
            loss_list_test.append(loss.item())

        # Log the validation losses that were just collected (previously the
        # reset training ``loss_list`` was averaged by mistake).
        writer.add_scalar('loss_valid', np.mean(loss_list_test), step)
        writer.add_text('script_valid', text_list[0], step)

        image = matrix_to_plt_image(mel_batch[0].cpu().detach().numpy().T,
                                    text_list[0])
        writer.add_image('mel_in_valid', image, step,
                         dataformats="HWC")  # (1, 80, T)

        image = matrix_to_plt_image(mel_out[0].cpu().detach().numpy().T,
                                    text_list[0])
        writer.add_image('mel_out_valid', image, step,
                         dataformats="HWC")  # (1, 80, T)

        # ---------- free-running inference (zeroed decoder input) ----------
        loss_list_test = list()
        for i, data in tqdm(enumerate(valid_data_loader),
                            total=int(
                                len(valid_data_loader.dataset) /
                                valid_data_loader.batch_size)):
            path_list, mel_batch, encoded_batch, text_list, mel_length_list, encoded_length_list = data
            zero_batch = torch.zeros_like(mel_batch)
            mel_out, stop_tokens, _, _ = model(encoded_batch, zero_batch)
            loss = nn.L1Loss()(mel_out.cuda(), mel_batch.cuda())
            loss_list_test.append(loss.item())

        writer.add_scalar('loss_infer', np.mean(loss_list_test), step)
        writer.add_text('script_infer', text_list[0], step)

        image = matrix_to_plt_image(mel_batch[0].cpu().detach().numpy().T,
                                    text_list[0])
        writer.add_image('mel_in_infer', image, step,
                         dataformats="HWC")  # (1, 80, T)

        image = matrix_to_plt_image(mel_out[0].cpu().detach().numpy().T,
                                    text_list[0])
        writer.add_image('mel_out_infer', image, step,
                         dataformats="HWC")  # (1, 80, T)

    return
Пример #15
0
    def train(
        self,
        mention_dataset,
        candidate_dataset,
        lr=1e-5,
        max_ctxt_len=32,
        max_title_len=50,
        max_desc_len=100,
        traindata_size=1000000,
        model_save_interval=10000,
        grad_acc_step=1,
        max_grad_norm=1.0,
        epochs=1,
        warmup_propotion=0.1,
        fp16=False,
        fp16_opt_level=None,
        parallel=False,
    ):
        """Train the cross-encoder re-ranker, one mention per step.

        For each mention, its gold page plus the pages from its
        ``nearest_neighbors`` list are merged with the mention tokens into
        candidate sequences; cross-entropy pushes the gold candidate (index 0)
        to the top. Supports fp16 (apex), data parallelism and gradient
        accumulation; checkpoints every ``model_save_interval`` steps when
        ``self.model_path`` is set.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        scheduler = get_scheduler(1, grad_acc_step, epochs, warmup_propotion,
                                  optimizer, traindata_size)

        if fp16:
            assert fp16_opt_level is not None
            self.model, optimizer = to_fp16(self.model, optimizer,
                                            fp16_opt_level)

        if parallel:
            self.model = to_parallel(self.model)

        for e in range(epochs):
            dataloader = DataLoader(mention_dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    collate_fn=my_collate_fn_json,
                                    num_workers=2)
            bar = tqdm(total=traindata_size)
            for step, (input_ids, labels, lines) in enumerate(dataloader):
                if step > traindata_size:
                    break

                if self.logger:
                    self.logger.debug("%s step", step)
                    self.logger.debug("%s data in batch", len(input_ids))
                    self.logger.debug("%s unique labels in %s labels",
                                      len(set(labels)), len(labels))

                # Gold page first, then its nearest neighbors as negatives.
                # (loop variable renamed from ``nn`` to avoid shadowing the
                # conventional torch.nn alias)
                pages = list(labels)
                for neighbor in lines[0]["nearest_neighbors"]:
                    if neighbor not in pages:
                        pages.append(str(neighbor))
                candidate_input_ids = candidate_dataset.get_pages(
                    pages,
                    max_title_len=max_title_len,
                    max_desc_len=max_desc_len)

                inputs = self.merge_mention_candidate(input_ids[0],
                                                      candidate_input_ids)

                # Pad to (num_candidates, seq); 0 is both pad and mask value.
                inputs = pad_sequence(
                    [torch.LongTensor(token) for token in inputs],
                    padding_value=0).t().to(self.device)
                input_mask = inputs > 0
                scores = self.model(inputs, input_mask)

                # Gold candidate is always at index 0.
                target = torch.LongTensor([0]).to(self.device)
                # NOTE(review): both tensors are unsqueezed before
                # cross_entropy — verify this matches the shape `scores`
                # actually has (targets are normally rank input.dim()-1).
                loss = F.cross_entropy(scores.unsqueeze(0),
                                       target.unsqueeze(0),
                                       reduction="mean")

                if self.logger:
                    self.logger.debug("Train loss: %s", loss.item())

                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # Step/clip only every grad_acc_step mini-batches.
                if (step + 1) % grad_acc_step == 0:
                    if fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                       max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                    if self.logger:
                        self.logger.debug("Back propagation in step %s",
                                          step + 1)
                        self.logger.debug("LR: %s", scheduler.get_lr())

                if self.use_mlflow:
                    mlflow.log_metric("train loss", loss.item(), step=step)

                if self.model_path is not None and step % model_save_interval == 0:
                    save_model(self.model, self.model_path)

                bar.update(len(input_ids))
                bar.set_description(f"Loss: {loss.item()}")
def main():
    global args
    args = parse_args(sys.argv[1])

    # -------------------- default arg settings for this model --------------------
    # num of display cols for visdom used during training time
    # real, fake1, fake2, recon1, recon2, wr1, wr2
    args.display_ncols = 6

    # define norm, either a single norm or define normalization method for each network separately
    if hasattr(args, 'norm'):
        args.normD = args.norm
        args.normQ = args.norm
        args.normG = args.norm

    args.gpu_ids = list(
        range(len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))))
    args.device = torch.device('cuda:0')

    args.timestamp = time.strftime(
        '%m%d%H%M%S', time.localtime())  # add timestamp to ckpt_dir
    args.ckpt_dir += '_' + args.timestamp

    # ================================================================================
    # define args before logging args
    # -------------------- init ckpt_dir, logging --------------------
    os.makedirs(args.ckpt_dir, mode=0o777, exist_ok=True)

    # init visu
    visualizer = Visualizer(args)

    # log all the settings
    visualizer.logger.log('sys.argv:\n' + ' '.join(sys.argv))
    for arg in sorted(vars(args)):
        visualizer.logger.log('{:20s} {}'.format(arg, getattr(args, arg)))
    visualizer.logger.log('')

    # -------------------- code copy --------------------
    # copy config yaml
    shutil.copyfile(sys.argv[1],
                    osp.join(args.ckpt_dir, osp.basename(sys.argv[1])))

    # TODO: delete after clean up!
    repo_basename = osp.basename(osp.dirname(osp.abspath(__file__)))
    repo_path = osp.join(args.ckpt_dir, repo_basename)
    os.makedirs(repo_path, mode=0o777, exist_ok=True)

    walk_res = os.walk('.')
    useful_paths = [
        path for path in walk_res
        if '.git' not in path[0] and 'checkpoints' not in path[0]
        and 'configs' not in path[0] and '__pycache__' not in path[0]
        and 'tee_dir' not in path[0] and 'tmp' not in path[0]
    ]
    for p in useful_paths:
        for item in p[-1]:
            if not (item.endswith('.py') or item.endswith('.c')
                    or item.endswith('.h') or item.endswith('.md')):
                continue
            old_path = osp.join(p[0], item)
            new_path = osp.join(repo_path, p[0][2:], item)
            basedir = osp.dirname(new_path)
            os.makedirs(basedir, mode=0o777, exist_ok=True)
            shutil.copyfile(old_path, new_path)

    # -------------------- dataset & loader --------------------
    train_dataset = datasets.__dict__[args.dataset](
        train=True,
        transform=transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                 inplace=True)
        ]),
        args=args)
    visualizer.logger.log('train_dataset: ' + str(train_dataset))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        worker_init_fn=lambda x: np.random.seed((torch.initial_seed()) %
                                                (2**32)))

    # change test html / ckpt saving frequency
    args.html_iter_freq = len(train_loader) // args.html_per_train_epoch
    visualizer.logger.log('change args.html_iter_freq to %s' %
                          args.html_iter_freq)
    args.save_iter_freq = len(train_loader) // args.html_per_train_epoch
    visualizer.logger.log('change args.save_iter_freq to %s' %
                          args.html_iter_freq)

    val_dataset = datasets.__dict__[args.dataset](
        train=False,
        transform=transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        args=args)
    visualizer.logger.log('val_dataset: ' + str(val_dataset))

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.test_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        worker_init_fn=lambda x: np.random.seed((torch.initial_seed()) %
                                                (2**32)))

    # ================================================================================
    # -------------------- create model --------------------
    model_dict = dict()
    model_dict['D_nets'] = []
    model_dict['G_nets'] = []

    # D, Q
    assert args.which_model_netD == "multiscale_separated"
    D_names = ['M', 'R', 'WR']
    infogan_func = models.define_infoGAN
    model_dict['D'], model_dict['Q'] = infogan_func(
        args.output_nc,
        args.ndf,
        args.which_model_netD,
        args.n_layers_D,
        args.n_layers_Q,
        16,  # num_dis_classes: since we group 4 binary bits, 2^4 = 16.
        args.passwd_length // 4,
        args.normD,
        args.normQ,
        args.init_type,
        args.init_gain,
        D_names=D_names)

    model_dict['G_nets'].append(model_dict['Q'])
    model_dict['D_nets'].append(model_dict['D'])

    # G
    model_dict['G'] = models.define_G(
        args.input_nc + args.passwd_length,
        args.output_nc,
        args.ngf,
        args.which_model_netG,
        args.n_downsample_G,
        args.normG,
        args.dropout,
        args.init_type,
        args.init_gain,
        args.passwd_length,
        use_leaky=args.use_leakyG,
        use_resize_conv=args.use_resize_conv,
        padding_type=args.padding_type,
    )
    model_dict['G_nets'].append(model_dict['G'])

    # FR
    netFR = models.sphere20a(feature=args.feature_layer)
    netFR = torch.nn.DataParallel(netFR).cuda()
    netFR.module.load_state_dict(
        torch.load('./pretrained_models/sphere20a_20171020.pth',
                   map_location='cpu'))
    model_dict['FR'] = netFR
    model_dict['D_nets'].append(netFR)

    # log all the models
    visualizer.logger.log('model_dict')
    for k, v in model_dict.items():
        visualizer.logger.log(k + ':')
        if isinstance(v, list):
            visualizer.logger.log('list, len: ' + str(len(v)))
            for item in v:
                visualizer.logger.log(item.module.__class__.__name__, end=' ')
            visualizer.logger.log('')
        else:
            visualizer.logger.log(v)

    # -------------------- criterions --------------------
    criterion_dict = {
        'GAN':
        models.GANLoss(args.gan_mode).to(args.device),
        'FR':
        models.AngleLoss().to(args.device),
        'L1':
        torch.nn.L1Loss().to(args.device),
        'DIS':
        torch.nn.CrossEntropyLoss().to(args.device),
        'Feat':
        torch.nn.CosineEmbeddingLoss().to(args.device)
        if args.feature_loss == 'cos' else torch.nn.MSELoss().to(args.device)
    }

    # -------------------- optimizers --------------------
    # considering separate optimizer for each network?
    if not hasattr(args, 'G_lr'):
        args.G_lr = args.lr
    if not hasattr(args, 'Q_lr'):
        args.Q_lr = args.lr
    if not hasattr(args, 'D_lr'):
        args.D_lr = args.lr
    if not hasattr(args, 'FR_lr'):
        args.FR_lr = args.lr * 0.1

    optimizer_G_params = [{
        'params': model_dict['G'].parameters(),
        'lr': args.G_lr
    }, {
        'params': model_dict['Q'].parameters(),
        'lr': args.Q_lr
    }]
    optimizer_G = torch.optim.Adam(optimizer_G_params,
                                   betas=(args.beta1, 0.999),
                                   weight_decay=args.weight_decay)

    optimizer_D_params = [{
        'params': model_dict['D'].parameters(),
        'lr': args.D_lr
    }, {
        'params': netFR.parameters(),
        'lr': args.FR_lr
    }]
    optimizer_D = torch.optim.Adam(optimizer_D_params,
                                   betas=(args.beta1, 0.999),
                                   weight_decay=args.weight_decay)

    optimizer_dict = {
        'G': optimizer_G,
        'D': optimizer_D,
    }

    # -------------------- resume --------------------
    if args.resume:
        if osp.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location='cpu')
            # better to restore to the exact iteration. I'm lazy here.
            args.start_epoch = checkpoint['epoch'] + 1

            for name, net in model_dict.items():
                if isinstance(net, list):
                    continue
                if hasattr(args, 'not_resume_models') and (
                        name in args.not_resume_models):
                    continue
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                if 'state_dict_' + name in checkpoint:
                    try:
                        net.load_state_dict(checkpoint['state_dict_' + name])
                    except Exception as e:
                        visualizer.logger.log('fail to load model ' + name +
                                              ' ' + str(e))
                else:
                    visualizer.logger.log('model ' + name +
                                          ' not in checkpoints, just skip')

            if args.resume_optimizer:
                for name, optimizer in optimizer_dict.items():
                    if 'optimizer_' + name in checkpoint:
                        optimizer.load_state_dict(checkpoint['optimizer_' +
                                                             name])
                    else:
                        visualizer.logger.log('optimizer ' + name +
                                              ' not in checkpoints, just skip')

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
        gc.collect()

    # -------------------- miscellaneous --------------------
    torch.backends.cudnn.enabled = True

    # generated image pool/buffer
    M_pool = ImagePool(args.pool_size)
    R_pool = ImagePool(args.pool_size)
    WR_pool = ImagePool(args.pool_size)

    # generate fixed passwords for test
    fixed_z, _, fixed_rand_z, _, \
    fixed_inv_z, _, fixed_rand_inv_z, _, \
    fixed_rand_inv_2nd_z, _ = generate_code(
        args.passwd_length, args.test_size, args.device,
        inv=True, use_minus_one=args.use_minus_one, gen_random_WR=True)
    print('fixed_z')
    print(fixed_z)
    fixed = {
        'z': fixed_z,
        'rand_z': fixed_rand_z,
        'inv_z': fixed_inv_z,
        'rand_inv_z': fixed_rand_inv_z,
        'rand_inv_2nd_z': fixed_rand_inv_2nd_z
    }
    gc.collect()

    for epoch in range(args.start_epoch, args.num_epochs):
        visualizer.logger.log('epoch ' + str(epoch))
        # turn on train mode
        model_dict['G'].train()
        model_dict['Q'].train()
        model_dict['D'].train()
        model_dict['FR'].train()

        # train
        epoch_start_time = time.time()
        train(train_loader, model_dict, criterion_dict, optimizer_dict, M_pool,
              R_pool, WR_pool, visualizer, epoch, args, val_loader, fixed)
        epoch_time = time.time() - epoch_start_time
        message = 'epoch %s total time %s\n' % (epoch, epoch_time)
        visualizer.logger.log(message)
        gc.collect()

        # save model
        if epoch % args.save_epoch_freq == 0:
            save_model(epoch,
                       model_dict,
                       optimizer_dict,
                       args,
                       iter=len(train_loader),
                       save_sep=True)

        # test model, save to html for visualization
        if epoch % args.html_epoch_freq == 0:
            validate(val_loader,
                     model_dict,
                     visualizer,
                     epoch,
                     args,
                     fixed,
                     iter=len(train_loader))
        gc.collect()
def train(train_loader, model_dict, criterion_dict, optimizer_dict, fake_pool,
          recon_pool, WR_pool, visualizer, epoch, args, val_loader, fixed):
    """Run one training epoch of the password-conditioned face GAN.

    Per mini-batch: generates the anonymized image M (`fake`), the
    correct-password reconstruction R (`recon`) and wrong-password
    reconstructions WR (`rand_recon`, `rand_recon_2nd`), then alternates a
    discriminator step (GAN + face-recognition losses) and a generator step
    (GAN + infoGAN auxiliary + FR + feature + L1/recon losses), with periodic
    logging, visualization, checkpointing and validation.

    Args:
        train_loader: yields (img, label, landmarks, img_path) batches.
        model_dict: dict with 'G', 'D', 'Q', 'FR' nets plus 'G_nets'/'D_nets' lists.
        criterion_dict: dict with 'GAN', 'FR', 'L1', 'DIS', 'Feat' criteria.
        optimizer_dict: dict with 'G' and 'D' optimizers.
        fake_pool, recon_pool, WR_pool: ImagePool history buffers for the
            M / R / WR image streams, used to stabilize D training.
        visualizer: logging / plotting helper.
        epoch: current epoch index (used for logging and checkpoint names).
        args: parsed options namespace.
        val_loader: validation loader passed through to validate().
        fixed: dict of fixed passwords so validation output is comparable
            across epochs.
    """
    iter_data_time = time.time()

    for i, (img, label, landmarks, img_path) in enumerate(train_loader):
        # Drop ragged last batch: downstream code assumes exactly batch_size.
        if img.size(0) != args.batch_size:
            continue

        img_cuda = img.cuda(non_blocking=True)

        if i % args.print_loss_freq == 0:
            iter_start_time = time.time()
            t_data = iter_start_time - iter_data_time

        visualizer.reset()

        # -------------------- forward & get aligned --------------------
        # Build an affine sampling grid that crops/aligns each face to the
        # 112x96 input expected by the sphere20a FR net (see netFR above).
        theta = alignment(landmarks)
        grid = torch.nn.functional.affine_grid(
            theta, torch.Size((args.batch_size, 3, 112, 96)))

        # -------------------- generate password --------------------
        z, dis_target, rand_z, rand_dis_target, \
        inv_z, inv_dis_target, rand_inv_z, rand_inv_dis_target, \
        rand_inv_2nd_z, rand_inv_2nd_dis_target = generate_code(args.passwd_length,
                                                                args.batch_size,
                                                                args.device,
                                                                inv=True,
                                                                use_minus_one=args.use_minus_one,
                                                                gen_random_WR=True)
        real_aligned = grid_sample(img_cuda, grid)  # (B, 3, h, w)
        # Channel swap [2,1,0]: presumably RGB -> BGR ordering for the
        # pretrained sphere20a weights -- TODO confirm against its training data.
        real_aligned = real_aligned[:, [2, 1, 0], ...]

        # NOTE(review): G is fed CPU tensors (img, z.cpu()); presumably the
        # generator is wrapped so inputs are moved/scattered internally -- confirm.
        fake = model_dict['G'](img, z.cpu())
        fake_aligned = grid_sample(fake, grid)
        fake_aligned = fake_aligned[:, [2, 1, 0], ...]

        # R: recover the original face by re-encoding with the inverse password.
        recon = model_dict['G'](fake, inv_z)
        recon_aligned = grid_sample(recon, grid)
        recon_aligned = recon_aligned[:, [2, 1, 0], ...]

        rand_fake = model_dict['G'](img, rand_z.cpu())
        rand_fake_aligned = grid_sample(rand_fake, grid)
        rand_fake_aligned = rand_fake_aligned[:, [
            2,
            1,
            0,
        ], ...]

        # WR: decode with wrong passwords; should NOT recover the identity.
        rand_recon = model_dict['G'](fake, rand_inv_z)
        rand_recon_aligned = grid_sample(rand_recon, grid)
        rand_recon_aligned = rand_recon_aligned[:, [2, 1, 0], ...]

        rand_recon_2nd = model_dict['G'](fake, rand_inv_2nd_z)
        rand_recon_2nd_aligned = grid_sample(rand_recon_2nd, grid)
        rand_recon_2nd_aligned = rand_recon_2nd_aligned[:, [2, 1, 0], ...]

        # init loss dict for plot & print
        current_losses = {}

        # -------------------- D PART --------------------
        # init
        set_requires_grad(model_dict['G_nets'], False)
        set_requires_grad(model_dict['D_nets'], True)
        optimizer_dict['D'].zero_grad()
        loss_D = 0.

        # ========== Face Recognition (FR) losses (L_{adv}, L_{rec\_cls}) ==========
        # FAKE FRs
        # M
        id_fake = model_dict['FR'](fake_aligned.detach())[0]
        loss_D_FR_fake = criterion_dict['FR'](id_fake, label.to(args.device))

        # R & WR
        # Note the sign flip on R: FR is trained NOT to classify the correct
        # reconstruction as the true label here (adversarial direction).
        id_recon = model_dict['FR'](recon_aligned.detach())[0]
        loss_D_FR_recon = -criterion_dict['FR'](id_recon, label.to(
            args.device))

        id_rand_recon = model_dict['FR'](rand_recon_aligned.detach())[0]
        loss_D_FR_rand_recon = criterion_dict['FR'](id_rand_recon,
                                                    label.to(args.device))

        loss_D_FR_fake_total = args.lambda_FR_M * loss_D_FR_fake + loss_D_FR_recon \
                               + args.lambda_FR_WR * loss_D_FR_rand_recon
        loss_D_FR_fake_avg = loss_D_FR_fake_total / float(1. +
                                                          args.lambda_FR_M +
                                                          args.lambda_FR_WR)
        current_losses.update({
            'D_FR_M': loss_D_FR_fake.item(),
            'D_FR_R': loss_D_FR_recon.item(),
            'D_FR_WR': loss_D_FR_rand_recon.item(),
        })

        # REAL FR
        id_real = model_dict['FR'](real_aligned)[0]
        loss_D_FR_real = criterion_dict['FR'](id_real, label.to(args.device))

        loss_D += args.lambda_FR * (loss_D_FR_real + loss_D_FR_fake_avg) * 0.5
        current_losses.update({
            'D_FR_real': loss_D_FR_real.item(),
            'D_FR_fake': loss_D_FR_fake_avg.item()
        })

        # ========== GAN loss (L_{GAN}) ==========
        # fake
        # Query the image pools so D also sees historical generator outputs.
        all_M = torch.cat((
            fake.detach().cpu(),
            rand_fake.detach().cpu(),
        ), 0)
        pred_D_M = model_dict['D'](fake_pool.query(all_M,
                                                   batch_size=args.batch_size),
                                   'M')
        loss_D_M = criterion_dict['GAN'](pred_D_M, False)

        # R
        pred_D_recon = model_dict['D'](recon_pool.query(
            recon.detach().cpu(), batch_size=args.batch_size), 'R')
        loss_D_recon = criterion_dict['GAN'](pred_D_recon, False)

        # WR
        all_WR = torch.cat(
            (rand_recon.detach().cpu(), rand_recon_2nd.detach().cpu()), 0)
        pred_D_WR = model_dict['D'](WR_pool.query(all_WR,
                                                  batch_size=args.batch_size),
                                    'WR')
        loss_D_WR = criterion_dict['GAN'](pred_D_WR, False)

        # Weighted average of the three fake streams (M / R / WR).
        loss_D_fake_total = args.lambda_GAN_M * loss_D_M + \
                            args.lambda_GAN_recon * loss_D_recon + \
                            args.lambda_GAN_WR * loss_D_WR
        loss_D_fake_total_weights = args.lambda_GAN_M + \
                                    args.lambda_GAN_recon + \
                                    args.lambda_GAN_WR
        loss_D_GAN_fake = loss_D_fake_total / loss_D_fake_total_weights
        current_losses.update({
            'D_GAN_M': loss_D_M.item(),
            'D_GAN_R': loss_D_recon.item(),
            'D_GAN_WR': loss_D_WR.item()
        })

        # real
        # The same real image is scored by all three discriminator heads.
        pred_D_real_M = model_dict['D'](img, 'M')
        pred_D_real_R = model_dict['D'](img, 'R')
        pred_D_real_WR = model_dict['D'](img, 'WR')

        loss_D_real_M = criterion_dict['GAN'](pred_D_real_M, True)
        loss_D_real_R = criterion_dict['GAN'](pred_D_real_R, True)
        loss_D_real_WR = criterion_dict['GAN'](pred_D_real_WR, True)

        loss_D_GAN_real = (args.lambda_GAN_M * loss_D_real_M +
                           args.lambda_GAN_recon * loss_D_real_R +
                           args.lambda_GAN_WR * loss_D_real_WR) / \
                          (args.lambda_GAN_M +
                           args.lambda_GAN_recon +
                           args.lambda_GAN_WR)

        loss_D += args.lambda_GAN * (loss_D_GAN_fake + loss_D_GAN_real) * 0.5
        current_losses.update({
            'D_GAN_real': loss_D_GAN_real.item(),
            'D_GAN_fake': loss_D_GAN_fake.item()
        })
        current_losses['D'] = loss_D.item()

        # D backward and optimizer steps
        loss_D.backward()
        optimizer_dict['D'].step()

        # -------------------- G PART --------------------
        # init
        set_requires_grad(model_dict['D_nets'], False)
        set_requires_grad(model_dict['G_nets'], True)
        optimizer_dict['G'].zero_grad()
        loss_G = 0

        # ========== GAN loss (L_{GAN}) ==========
        # Generator tries to make all three streams look real to D.
        pred_G_fake = model_dict['D'](fake, 'M')
        loss_G_GAN_fake = criterion_dict['GAN'](pred_G_fake, True)

        pred_G_recon = model_dict['D'](recon, 'R')
        loss_G_GAN_recon = criterion_dict['GAN'](pred_G_recon, True)

        pred_G_WR = model_dict['D'](rand_recon, 'WR')
        loss_G_GAN_WR = criterion_dict['GAN'](pred_G_WR, True)

        loss_G_GAN_total = args.lambda_GAN_M * loss_G_GAN_fake + \
                           args.lambda_GAN_recon * loss_G_GAN_recon + \
                           args.lambda_GAN_WR * loss_G_GAN_WR
        loss_G_GAN_total_weights = args.lambda_GAN_M + args.lambda_GAN_recon + args.lambda_GAN_WR
        loss_G_GAN = loss_G_GAN_total / loss_G_GAN_total_weights
        loss_G += args.lambda_GAN * loss_G_GAN

        current_losses.update({
            'G_GAN_M': loss_G_GAN_fake.item(),
            'G_GAN_R': loss_G_GAN_recon.item(),
            'G_GAN_WR': loss_G_GAN_WR.item(),
            'G_GAN': loss_G_GAN.item()
        })

        # ========== infoGAN loss (L_{aux}) ==========
        # Q must recover the password from each (input, output) image pair;
        # the password is grouped into 4-bit chunks, one classifier per chunk.
        if args.lambda_dis > 0:
            fake_dis_logits = model_dict['Q'](infoGAN_input(img_cuda, fake))
            infogan_fake_acc = 0
            loss_G_fake_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = fake_dis_logits[dis_idx].max(dim=1)[1]
                b = dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_fake_acc += acc.item()
                loss_G_fake_dis += criterion_dict['DIS'](
                    fake_dis_logits[dis_idx], dis_target[:, dis_idx])
            infogan_fake_acc = infogan_fake_acc / float(
                args.passwd_length // 4)

            recon_dis_logits = model_dict['Q'](infoGAN_input(fake, recon))
            infogan_recon_acc = 0
            loss_G_recon_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = recon_dis_logits[dis_idx].max(dim=1)[1]
                b = inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_recon_acc += acc.item()
                loss_G_recon_dis += criterion_dict['DIS'](
                    recon_dis_logits[dis_idx], inv_dis_target[:, dis_idx])
            infogan_recon_acc = infogan_recon_acc / float(
                args.passwd_length // 4)

            rand_recon_dis_logits = model_dict['Q'](infoGAN_input(
                fake, rand_recon))
            infogan_rand_recon_acc = 0
            loss_G_recon_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = rand_recon_dis_logits[dis_idx].max(dim=1)[1]
                b = rand_inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_rand_recon_acc += acc.item()
                loss_G_recon_rand_dis += criterion_dict['DIS'](
                    rand_recon_dis_logits[dis_idx],
                    rand_inv_dis_target[:, dis_idx])
            infogan_rand_recon_acc = infogan_rand_recon_acc / float(
                args.passwd_length // 4)

            dis_loss_total = loss_G_fake_dis + loss_G_recon_dis + loss_G_recon_rand_dis
            dis_acc_total = infogan_fake_acc + infogan_recon_acc + infogan_rand_recon_acc
            dis_cnt = 3

            loss_G += args.lambda_dis * dis_loss_total
            current_losses.update({
                'dis': dis_loss_total.item(),
                'dis_acc': dis_acc_total / float(dis_cnt)
            })

        # ========== Face Recognition (FR) loss (L_{adv}, L{rec_cls}})==========
        # (netFR must not be fixed)
        # Signs mirror the D step: G wants M and WR to fool FR (negative loss)
        # while R must still be recognized as the true identity.
        id_fake_G, fake_feat = model_dict['FR'](fake_aligned)
        loss_G_FR_fake = -criterion_dict['FR'](id_fake_G, label.to(
            args.device))

        id_recon_G, recon_feat = model_dict['FR'](recon_aligned)
        loss_G_FR_recon = criterion_dict['FR'](id_recon_G,
                                               label.to(args.device))

        id_rand_recon_G, rand_recon_feat = model_dict['FR'](rand_recon_aligned)
        loss_G_FR_rand_recon = -criterion_dict['FR'](id_rand_recon_G,
                                                     label.to(args.device))

        loss_G_FR_avg = (args.lambda_FR_M * loss_G_FR_fake +
                         loss_G_FR_recon +
                         args.lambda_FR_WR * loss_G_FR_rand_recon) /\
                        (args.lambda_FR_M + 1. + args.lambda_FR_WR)
        loss_G += args.lambda_FR * loss_G_FR_avg

        current_losses.update({
            'G_FR_M': loss_G_FR_fake.item(),
            'G_FR_R': loss_G_FR_recon.item(),
            'G_FR_WR': loss_G_FR_rand_recon.item(),
            'G_FR': loss_G_FR_avg.item()
        })

        # ========== Feature losses (L_{feat} is the sum of the three L_{dis}'s) ==========
        if args.feature_loss == 'cos':  # make cos sim target
            # Target -1: push the paired FR feature vectors apart.
            FR_cos_sim_target = torch.empty(size=(args.batch_size, 1),
                                            dtype=torch.float32,
                                            device=args.device)
            FR_cos_sim_target.fill_(-1.)
        else:
            FR_cos_sim_target = None

        id_rand_fake_G, rand_fake_feat = model_dict['FR'](rand_fake_aligned)
        id_rand_recon_2nd_G, rand_recon_2nd_feat = model_dict['FR'](
            rand_recon_2nd_aligned)

        # Different passwords should yield dissimilar anonymized identities.
        if args.lambda_Feat:
            loss_G_feat = get_feat_loss(fake_feat, rand_fake_feat,
                                        FR_cos_sim_target, args.feature_loss,
                                        criterion_dict)
            current_losses['G_feat'] = loss_G_feat.item()
        else:
            loss_G_feat = 0.

        if args.lambda_WR_Feat:
            loss_G_WR_feat = get_feat_loss(rand_recon_feat,
                                           rand_recon_2nd_feat,
                                           FR_cos_sim_target,
                                           args.feature_loss, criterion_dict)
            current_losses['G_WR_feat'] = loss_G_WR_feat.item()
        else:
            loss_G_WR_feat = 0.

        if args.lambda_false_recon_diff:
            loss_G_M_WR_feat = get_feat_loss(fake_feat, rand_recon_feat,
                                             FR_cos_sim_target,
                                             args.feature_loss, criterion_dict)
            current_losses['G_feat_M_WR'] = loss_G_M_WR_feat.item()
        else:
            loss_G_M_WR_feat = 0.

        loss_G += args.lambda_Feat * loss_G_feat + \
                  args.lambda_WR_Feat * loss_G_WR_feat + \
                  args.lambda_false_recon_diff * loss_G_M_WR_feat

        # ========== L1/Recon losses (L_1, L_{rec}) ==========
        # All three outputs are pulled toward the original image in L1 sense.
        loss_G_L1 = criterion_dict['L1'](fake, img_cuda)
        loss_G_rand_recon_L1 = criterion_dict['L1'](rand_recon, img_cuda)
        loss_G_recon = criterion_dict['L1'](recon, img_cuda)

        loss_G += args.lambda_L1 * loss_G_L1 + \
                  args.lambda_rand_recon_L1 * loss_G_rand_recon_L1 + \
                  args.lambda_G_recon * loss_G_recon

        current_losses.update({
            'L1_M': loss_G_L1.item(),
            'recon': loss_G_recon.item(),
            'L1_WR': loss_G_rand_recon_L1.item()
        })

        current_losses['G'] = loss_G.item()

        # G backward and optimizer steps
        loss_G.backward()
        optimizer_dict['G'].step()

        # -------------------- LOGGING PART --------------------
        if i % args.print_loss_freq == 0:
            t = (time.time() - iter_start_time) / args.batch_size
            visualizer.print_current_losses(epoch, i, current_losses, t,
                                            t_data)
            if args.display_id > 0 and i % args.plot_loss_freq == 0:
                visualizer.plot_current_losses(epoch,
                                               float(i) / len(train_loader),
                                               args, current_losses)

        if i % args.visdom_visual_freq == 0:
            save_result = i % args.update_html_freq == 0

            current_visuals = OrderedDict()
            current_visuals['real'] = img.detach()
            current_visuals['fake'] = fake.detach()
            current_visuals['rand_fake'] = rand_fake.detach()
            current_visuals['recon'] = recon.detach()
            current_visuals['rand_recon'] = rand_recon.detach()
            current_visuals['rand_recon_2nd'] = rand_recon_2nd.detach()

            # A hung visdom server must not stall training: give the display
            # call 60s, then permanently disable visdom for this run.
            try:
                with time_limit(60):
                    visualizer.display_current_results(current_visuals, epoch,
                                                       save_result, args)
            except TimeoutException:
                visualizer.logger.log(
                    'TIME OUT visualizer.display_current_results epoch:{} iter:{}. Change display_id to -1'
                    .format(epoch, i))
                # disable visdom display ever since
                args.display_id = -1

        # +1 so that we do not save/test for 0th iteration
        if (i + 1) % args.save_iter_freq == 0:
            save_model(epoch,
                       model_dict,
                       optimizer_dict,
                       args,
                       iter=i,
                       save_sep=False)
            if args.display_id > 0:
                visualizer.vis.save([args.name])

        if (i + 1) % args.html_iter_freq == 0:
            validate(val_loader, model_dict, visualizer, epoch, args, fixed, i)

        if (i + 1) % args.print_loss_freq == 0:
            iter_data_time = time.time()
Пример #18
0
    # NOTE(review): fragment of a validation loop -- the enclosing `def` header
    # is missing from this chunk. Names such as val_dataloader, hde_dict_train,
    # device, model, train_step, test_loss/test_acc, epoch, start_epoch,
    # epochs, scheduler and optimizer come from the unseen enclosing scope.
    for i, (img, img_char) in enumerate(val_dataloader):
        train_step += 1

        # Stack every HDE embedding vector from the training dictionary into
        # one float32 matrix (rebuilt each batch).
        hde_arr = np.array([v for k, v in hde_dict_train.items()])
        hde_arr = torch.from_numpy(hde_arr).type(torch.float32)

        img, hde_arr = img.to(device), hde_arr.to(device)

        # Model scores each image against every candidate HDE vector.
        hde_distance = model(img, hde_arr)

        # loss
        # Recover the row index of the ground-truth HDE vector for each
        # character by exact elementwise match against the dictionary entry.
        hde_cur = []
        for char in img_char:
            for j, h_a in enumerate(hde_arr.cpu().numpy()):
                if (h_a == hde_dict_train[char].astype(np.float32)).all():
                    hde_cur.append(j)
        loss = util.loss_MCPL(hde_distance, hde_cur)

        acc = util.acc_cal(hde_distance, hde_cur)

        test_loss.append(loss.item())
        test_acc.append(acc)

    # Append epoch-level metrics to the results file and echo them to stdout.
    with open("result_1.txt","a+") as f:
        print('({:d} / {:d})  val Loss: {:.4f} val Accuracy {}'.format(epoch, start_epoch + epochs, np.mean(test_loss),
                                                                       np.mean(test_acc)), file=f, flush=True)
    print('({:d} / {:d})  val Loss: {:.4f} val Accuracy {}'.format(epoch, start_epoch + epochs, np.mean(test_loss), np.mean(test_acc)))

    util.save_model(model, epoch, scheduler.get_lr(), optimizer)
Пример #19
0
def train_temporal_model(model: nn.Module, data_loader, phases: list,
                         normal: list, channel_list: list, channel: int,
                         loss_func: nn.Module, optimizer, num_epochs: int,
                         tensorboard_folder: str, model_folder_name):
    """Train a multi-channel temporal model and checkpoint best weights per channel.

    For each epoch, every phase in ``phases`` (e.g. 'train'/'validate'/'test')
    is run; per-channel losses are accumulated, the best validation RMSE per
    channel is tracked, and de-normalized ("true") losses are written to
    TensorBoard. Saving happens in the ``finally`` block so the best weights
    survive an interrupted run (e.g. KeyboardInterrupt).

    Args:
        model: network to train; invoked as ``model(x)``.
        data_loader: dict mapping phase name -> DataLoader yielding (x, y).
        phases: ordered phase names; only 'train' updates weights.
        normal: per-channel normalizers exposing ``rmse_transform``.
        channel_list: identifiers for the predicted channels.
        channel: index into ``normal`` used for progress-bar / total loss.
        loss_func: loss module applied to (y_pred, y).
        optimizer: optimizer over ``model`` parameters.
        num_epochs: number of epochs to run.
        tensorboard_folder: SummaryWriter log directory.
        model_folder_name: filename prefix for saved per-channel checkpoints.
    """
    # time.clock() was deprecated since Python 3.3 and removed in 3.8;
    # time.perf_counter() is the documented replacement for elapsed time.
    since = time.perf_counter()
    channel_num = len(channel_list)

    writer = SummaryWriter(tensorboard_folder)
    # One checkpoint slot and one best-RMSE tracker per channel.
    save_dict, best_rmse, best_test_rmse = [{'model_state_dict': copy.deepcopy(model.state_dict()),
                                             'epoch': 0} for i in range(channel_num)], \
                                           [999999 for i in range(channel_num)], \
                                           [999999 for i in range(channel_num)]
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=.5,
                                                     patience=2,
                                                     threshold=1e-3,
                                                     min_lr=1e-6)

    try:
        for epoch in range(num_epochs):

            # running_loss[phase][i] accumulates channel i's loss; the extra
            # slot at index channel_num holds the combined loss.
            running_loss = {
                phase: [0.0 for i in range(channel_num + 1)]
                for phase in phases
            }
            for phase in phases:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                steps, ground_truth, prediction = 0, list(), list()
                tqdm_loader = tqdm(data_loader[phase], phase)

                for x, y in tqdm_loader:
                    # x.to(get_config("device"))
                    # y.to(get_config("device"))
                    # Gradients are only tracked in the training phase.
                    with torch.set_grad_enabled(phase == 'train'):
                        y_pred = model(x)
                        loss = loss_func(y_pred, y)

                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    ground_truth.append(y.cpu().detach().numpy())
                    prediction.append(y_pred.cpu().detach().numpy())

                    # running_loss[phase] += loss * y.size(0)
                    running_loss = loss_calculate(y, y_pred, running_loss,
                                                  phase, loss_func,
                                                  channel_num)
                    steps += y.size(0)

                    tqdm_loader.set_description(
                        f"{phase:8} epoch: {epoch:3} loss: {running_loss[phase][channel_num] / steps:3.6},"
                        f"true loss: {normal[channel].rmse_transform(running_loss[phase][channel_num] / steps):3.6}"
                    )
                    torch.cuda.empty_cache()

                for i in range(channel_num):
                    # Snapshot weights whenever a channel's validation RMSE improves.
                    if phase == 'validate' and running_loss[phase][
                            i] / steps <= best_rmse[i]:
                        best_rmse[i] = running_loss['validate'][i] / steps
                        save_dict[i].update(model_state_dict=copy.deepcopy(
                            model.state_dict()),
                                            epoch=epoch)
                    # Record the test RMSE of the epoch whose weights were kept.
                    if phase == 'test' and save_dict[i]['epoch'] == epoch:
                        best_test_rmse[i] = running_loss['test'][i] / steps

            # Plateau scheduler keyed on the combined training loss.
            scheduler.step(running_loss['train'][channel_num])

            # steps = len(data_loader['train'].dataset)    # NOTE: wrong -- kept for reference
            for i in range(channel_num):
                writer.add_scalars(
                    f'Loss channel{channel_list[i]}', {
                        f'{phase} true loss': normal[i].rmse_transform(
                            running_loss[phase][i] /
                            len(data_loader[phase].dataset))
                        for phase in phases
                    }, epoch)
            writer.add_scalars(
                f'Loss channel total', {
                    f'{phase} loss': normal[channel].rmse_transform(
                        running_loss[phase][channel_num] /
                        len(data_loader[phase].dataset))
                    for phase in phases
                }, epoch)
    finally:
        time_elapsed = time.perf_counter() - since
        print(f"cost time: {time_elapsed:.2} seconds")
        for i in range(channel_num):
            save_model(f"{model_folder_name}_{channel_list[i]}.pkl",
                       **save_dict[i])
Пример #20
0
    def train(self, config, model):
        """Train ``model`` on ``self.train_dataset`` with a warmup+cosine LR schedule.

        Logs mean loss and learning rate to TensorBoard every ``log_every``
        iterations, validates every ``valid_every`` iterations (saving whenever
        the validation MRR improves), checkpoints every ``save_every``
        iterations, and ends with a final evaluation on ``self.eval_dataset``.

        Args:
            config: dict of hyperparameters; keys used here include
                'batch_size', 'learning_rate', 'emb_learning_rate',
                'warmup_steps', 'nb_epoch', 'reload', 'log_every',
                'valid_every', 'save_every', 'tb_writer_dir', 'model_save_dir'.
            model: network whose forward pass returns the training loss; must
                expose query_embedding / name_embedding / body_embedding layers.
        """
        tb_writer = SummaryWriter(config['tb_writer_dir'])
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

        model.train()
        data_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            drop_last=True,
            num_workers=0)

        # Give the three embedding layers their own learning rate; every
        # other parameter falls into the base group at config['learning_rate'].
        emb_layers = nn.ModuleList([
            model.query_embedding, model.name_embedding, model.body_embedding
        ])
        emb_layers_paras = list(map(id, emb_layers.parameters()))
        base_paras = filter(lambda p: id(p) not in emb_layers_paras,
                            model.parameters())
        optimizer_grouped_parameters = [{
            'params': base_paras
        }, {
            'params': emb_layers.parameters(),
            'lr': config['emb_learning_rate']
        }]

        optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                      lr=config['learning_rate'],
                                      eps=1e-8)

        scheduler = self.get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config['warmup_steps'],
            num_training_steps=len(data_loader) * config['nb_epoch'])

        n_iters = len(data_loader)
        # config['reload'] is the global iteration to resume from (0 = fresh run).
        itr_global = config['reload'] + 1

        best_mrr = 0
        # Starting epoch is derived from the resume iteration count.
        for epoch in range(
                int(config['reload'] / n_iters) + 1, config['nb_epoch'] + 1):
            losses = []
            for batch in data_loader:
                model.train()
                batch_gpu = [tensor.to(self.device) for tensor in batch]
                loss = model(*batch_gpu)

                loss.backward()
                optimizer.step()
                scheduler.step()
                model.zero_grad()

                losses.append(loss.item())

                if itr_global % config['log_every'] == 0:
                    info = f'epo:[{epoch}/{config["nb_epoch"]}] itr:[{itr_global % n_iters}/{n_iters}] ' \
                           f'Loss={np.mean(losses)} learning rate: {optimizer.param_groups[0]["lr"]}'
                    logger.info(info)
                    tb_writer.add_scalar('loss', np.mean(losses),
                                         int(itr_global / 1000))
                    tb_writer.add_scalar('learning_rate',
                                         optimizer.param_groups[0]["lr"],
                                         int(itr_global / 1000))
                    # Reset so each log line averages only the window since
                    # the previous log point.
                    losses = []

                itr_global = itr_global + 1

                if itr_global % config['valid_every'] == 0:
                    logger.info("\nvalidating..")
                    mrr = self.test(config, model, self.valid_dataset)
                    # Keep only the checkpoint with the best validation MRR.
                    if mrr > best_mrr:
                        best_mrr = mrr
                        print(f'best mrr is {best_mrr}, save model...')
                        save_model(model, config['model_save_dir'], itr_global)

                if itr_global % config['save_every'] == 0:
                    save_model(model, config['model_save_dir'], itr_global)

        logger.info('test on eval dataset')
        self.test(config, model, self.eval_dataset)
        save_model(model, config['model_save_dir'],
                   str(itr_global) + "_finnal")
Пример #21
0
    # NOTE(review): fragment -- the enclosing function header is missing from
    # this chunk; `device`, `data_loaders`, `train_model`, `save_model` and
    # `hinge_loss` come from the surrounding (unseen) scope.
    # Load a fine-tuned AlexNet car/background binary classifier.
    model_path = './models/alexnet_car.pth'
    model = alexnet()
    num_classes = 2
    num_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(num_features, num_classes)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    # Freeze the backbone so it acts as a fixed feature extractor.
    for param in model.parameters():
        param.requires_grad = False
    # Create the SVM classifier: a fresh nn.Linear head replaces the loaded
    # one, so it is trainable (its parameters default to requires_grad=True).
    model.classifier[6] = nn.Linear(num_features, num_classes)
    # print(model)
    model = model.to(device)

    criterion = hinge_loss
    # The initial training set is very small, so use a reduced learning rate.
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    # Train 10 epochs in total, decaying the learning rate every 4 epochs.
    lr_schduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

    best_model = train_model(data_loaders,
                             model,
                             criterion,
                             optimizer,
                             lr_schduler,
                             num_epochs=10,
                             device=device)
    # Save the best model weights.
    save_model(best_model, 'models/best_linear_svm_alexnet_car.pth')
Пример #22
0
def train_model(data_loaders,
                data_sizes,
                model_name,
                model,
                criterion,
                optimizer,
                lr_scheduler,
                num_epochs=25,
                device=None):
    """Train and evaluate ``model``, keeping the best test-accuracy weights.

    Each epoch runs a 'train' phase followed by a 'test' phase. During
    training the model is expected to return ``(main, aux2, aux1)`` outputs
    (GoogLeNet-style); the loss adds the two auxiliary heads with weight 0.3,
    while predictions always come from the main head only. The learning-rate
    scheduler is stepped once per epoch after the train phase, and a
    checkpoint is written after every epoch.

    Returns:
        (model, loss_dict, acc_dict): the model reloaded with its best
        test-accuracy weights, plus per-phase loss and accuracy histories.
    """
    start = time.time()

    # Weights achieving the highest test accuracy so far.
    top_weights = copy.deepcopy(model.state_dict())
    top_acc = 0.0

    loss_dict = {'train': [], 'test': []}
    acc_dict = {'train': [], 'test': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ('train', 'test'):
            is_train = phase == 'train'
            # Switch batchnorm/dropout behavior for the current phase.
            model.train() if is_train else model.eval()

            total_loss = 0.0
            total_correct = 0

            for inputs, labels in data_loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Autograd history is recorded only while training.
                with torch.set_grad_enabled(is_train):
                    if is_train:
                        outputs, aux2, aux1 = model(inputs)
                        # Only the final classifier drives predictions.
                        _, preds = torch.max(outputs, 1)
                        aux_loss = criterion(aux2, labels) + criterion(aux1, labels)
                        loss = criterion(outputs, labels) + 0.3 * aux_loss
                        loss.backward()
                        optimizer.step()
                    else:
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                total_loss += loss.item() * inputs.size(0)
                total_correct += torch.sum(preds == labels.data)

            if is_train:
                lr_scheduler.step()

            epoch_loss = total_loss / data_sizes[phase]
            epoch_acc = total_correct.double() / data_sizes[phase]
            loss_dict[phase].append(epoch_loss)
            acc_dict[phase].append(epoch_acc)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))

            # Remember the best weights seen on the test set.
            if phase == 'test' and epoch_acc > top_acc:
                top_acc = epoch_acc
                top_weights = copy.deepcopy(model.state_dict())

        # Checkpoint after every epoch.
        util.save_model(model, './data/models/%s_%d.pth' % (model_name, epoch))

    elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(
        elapsed // 60, elapsed % 60))
    print('Best test Acc: {:4f}'.format(top_acc))

    # Restore the best weights before returning.
    model.load_state_dict(top_weights)
    return model, loss_dict, acc_dict
Пример #23
0
def train_model(data_loaders,
                data_sizes,
                model_name,
                model,
                criterion,
                optimizer,
                lr_scheduler,
                num_epochs=25,
                device=None):
    """Train/evaluate ``model`` while tracking top-1 and top-5 accuracy.

    Args:
        data_loaders: dict with 'train' and 'test' DataLoaders.
        data_sizes: dict with per-phase dataset sizes (used to average loss).
        model_name: tag used in log lines and checkpoint filenames.
        model: network to optimize.
        criterion: loss function.
        optimizer: stepped every batch during the train phase.
        lr_scheduler: stepped once per epoch after the train phase.
        num_epochs: number of epochs to run.
        device: device inputs/labels are moved to.

    Returns:
        (model, loss_dict, top1_acc_dict, top5_acc_dict) where the returned
        model carries the weights with the best test top-1 accuracy.

    NOTE(review): in the test phase the loader is assumed to yield 5-D
    multi-crop batches (N, N_crops, C, H, W), e.g. TenCrop — confirm at
    the caller.
    """
    since = time.time()

    best_model_weights = copy.deepcopy(model.state_dict())
    best_top1_acc = 0.0
    best_top5_acc = 0.0

    loss_dict = {'train': [], 'test': []}
    top1_acc_dict = {'train': [], 'test': []}
    top5_acc_dict = {'train': [], 'test': []}
    for epoch in range(num_epochs):

        print('{} - Epoch {}/{}'.format(model_name, epoch + 1, num_epochs))
        print('-' * 10)

        # Each epoch has a training and test phase
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            # running_corrects = 0
            running_top1_acc = 0.0
            running_top5_acc = 0.0

            # Iterate over data.
            for inputs, labels in data_loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    if phase == 'test':
                        N, N_crops, C, H, W = inputs.size()
                        result = model(inputs.view(
                            -1, C, H, W))  # fuse batch size and ncrops
                        outputs = result.view(N, N_crops,
                                              -1).mean(1)  # avg over crops
                    else:
                        outputs = model(inputs)
                    # print(outputs.shape)
                    # _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # compute top-k accuray
                    topk_list = metrics.topk_accuracy(outputs,
                                                      labels,
                                                      topk=(1, 5))
                    running_top1_acc += topk_list[0]
                    running_top5_acc += topk_list[1]

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                # running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                print('lr: {}'.format(optimizer.param_groups[0]['lr']))
                lr_scheduler.step()

            epoch_loss = running_loss / data_sizes[phase]
            # Accuracy was accumulated per batch, so average over the number
            # of batches (len of the loader), not over the dataset size.
            epoch_top1_acc = running_top1_acc / len(data_loaders[phase])
            epoch_top5_acc = running_top5_acc / len(data_loaders[phase])

            loss_dict[phase].append(epoch_loss)
            top1_acc_dict[phase].append(epoch_top1_acc)
            top5_acc_dict[phase].append(epoch_top5_acc)

            print('{} Loss: {:.4f} Top-1 Acc: {:.4f} Top-5 Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_top1_acc, epoch_top5_acc))

            # deep copy the model
            if phase == 'test' and epoch_top1_acc > best_top1_acc:
                best_top1_acc = epoch_top1_acc
                best_model_weights = copy.deepcopy(model.state_dict())
            if phase == 'test' and epoch_top5_acc > best_top5_acc:
                best_top5_acc = epoch_top5_acc

        # Save a checkpoint every 10 epochs; the model is moved to CPU for
        # a device-independent checkpoint, then moved back for training.
        if (epoch + 1) % 10 == 0:
            util.save_model(
                model.cpu(),
                '../data/models/%s_%d.pth' % (model_name, epoch + 1))
            model = model.to(device)

    time_elapsed = time.time() - since
    print('Training {} complete in {:.0f}m {:.0f}s'.format(
        model_name, time_elapsed // 60, time_elapsed % 60))
    print('Best test Top-1 Acc: {:4f}'.format(best_top1_acc))
    print('Best test Top-5 Acc: {:4f}'.format(best_top5_acc))

    # load best model weights
    model.load_state_dict(best_model_weights)
    return model, loss_dict, top1_acc_dict, top5_acc_dict
Пример #24
0
def do_train(cfg, model):
    """Train ``model`` according to ``cfg``, checkpointing the best epoch.

    Builds the criterion/optimizer/scheduler from factories, optionally
    resumes from ``cfg.resume_from``, logs one summary line per epoch,
    saves a checkpoint whenever validation loss improves, and stops early
    after ``cfg.early_stop`` epochs without improvement.
    """
    # get criterion -----------------------------
    criterion = criterion_factory.get_criterion(cfg)

    # get optimization --------------------------
    optimizer = optimizer_factory.get_optimizer(model, cfg)

    # initial best-so-far validation stats ------
    best = {
        'loss': float('inf'),
        'score': 0.0,
        'epoch': -1,
    }

    # resume model ------------------------------
    if cfg.resume_from:
        log.info('\n')
        log.info(f're-load model from {cfg.resume_from}')
        detail = util.load_model(cfg.resume_from, model, optimizer, cfg.device)
        best.update({
            'loss': detail['loss'],
            'score': detail['score'],
            'epoch': detail['epoch'],
        })

    # scheduler ---------------------------------
    scheduler = scheduler_factory.get_scheduler(cfg, optimizer, best['epoch'])

    # fp16 --------------------------------------
    if cfg.apex:
        # Fix: amp.initialize returns the patched (model, optimizer) pair;
        # the previous code discarded the return value, so mixed precision
        # was never actually applied to the objects used below.
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level='O1', verbosity=0)

    # setting dataset ---------------------------
    loader_train = dataset_factory.get_dataloader(cfg.data.train)
    loader_valid = dataset_factory.get_dataloader(cfg.data.valid)

    # start training ----------------------------
    start_time = datetime.now().strftime('%Y/%m/%d %H:%M:%S')
    log.info('\n')
    log.info(f'** start train [fold{cfg.fold}th] {start_time} **\n')
    log.info(
        'epoch    iter      rate     | smooth_loss/score | valid_loss/score | best_epoch/best_score |  min'
    )
    log.info(
        '-------------------------------------------------------------------------------------------------'
    )

    for epoch in range(best['epoch'] + 1, cfg.epoch):
        end = time.time()
        # Re-seed per epoch for reproducible shuffling/augmentation.
        util.set_seed(epoch)

        ## train model --------------------------
        train_results = run_nn(cfg.data.train,
                               'train',
                               model,
                               loader_train,
                               criterion=criterion,
                               optimizer=optimizer,
                               apex=cfg.apex,
                               epoch=epoch)

        ## valid model --------------------------
        with torch.no_grad():
            val_results = run_nn(cfg.data.valid,
                                 'valid',
                                 model,
                                 loader_valid,
                                 criterion=criterion,
                                 epoch=epoch)

        detail = {
            'score': val_results['score'],
            'loss': val_results['loss'],
            'epoch': epoch,
        }

        # Checkpoint only when validation loss improves (<= keeps ties).
        if val_results['loss'] <= best['loss']:
            best.update(detail)
            util.save_model(model, optimizer, detail, cfg.fold[0],
                            os.path.join(cfg.workdir, 'checkpoint'))

        log.info('%5.1f   %5d    %0.6f   |  %0.4f  %0.4f  |  %0.4f  %6.4f |  %6.1f     %6.4f    | %3.1f min' % \
                (epoch+1, len(loader_train), util.get_lr(optimizer), train_results['loss'], train_results['score'], val_results['loss'], val_results['score'], best['epoch'], best['score'], (time.time() - end) / 60))

        scheduler.step(
            val_results['loss'])  # if scheduler is reducelronplateau
        # scheduler.step()

        # early stopping-------------------------
        if cfg.early_stop:
            if epoch - best['epoch'] > cfg.early_stop:
                log.info('=================================> early stopping!')
                break
        time.sleep(0.01)
Пример #25
0
        # print(model)
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        lr_schduler = optim.lr_scheduler.StepLR(optimizer,
                                                step_size=8,
                                                gamma=0.96)

        # Make sure the checkpoint directory exists before training writes to it.
        util.check_dir('./data/models/')
        best_model, loss_dict, acc_dict = train_model(data_loaders,
                                                      data_sizes,
                                                      name,
                                                      model,
                                                      criterion,
                                                      optimizer,
                                                      lr_schduler,
                                                      num_epochs=100,
                                                      device=device)
        # Save the best model parameters.
        util.save_model(best_model, './data/models/best_%s.pth' % name)

        # Collect per-model histories for the summary plots below.
        res_loss[name] = loss_dict
        res_acc[name] = acc_dict

        print('train %s done' % name)
        print()

    util.save_png('loss', res_loss)
    util.save_png('acc', res_acc)
Пример #26
0
def train(train_loader, model_dict, criterion_dict, optimizer_dict, fake_pool, recon_pool, fake_pair_pool, WR_pool, visualizer, epoch, args, test_loader, fixed_z, fixed_rand_z):
    iter_data_time = time.time()

    for i, (img, label, landmarks, img_path) in enumerate(train_loader):
        iter_start_time = time.time()
        if i % args.print_loss_freq == 0:
            t_data = iter_start_time - iter_data_time

        visualizer.reset()
        batch_size = img.size(0)

        if args.lambda_dis > 0:
            # -------------------- generate password --------------------
            z, dis_target, rand_z, rand_dis_target, inv_z, inv_dis_target, another_rand_z, another_rand_dis_target = generate_code(args.passwd_length, batch_size, args.device, inv=True)

            # -------------------- forward --------------------
            # TODO: whether to detach
            fake = model_dict['G'](img, z.cpu())
            rand_fake = model_dict['G'](img, rand_z.cpu())
            if args.lambda_G_recon > 0:
                recon = model_dict['G'](fake, inv_z)
                rand_recon = model_dict['G'](fake, another_rand_z)
        else:
            fake = model_dict['G'](img)
            if args.lambda_G_recon > 0:
                recon = model_dict['G'](fake)

        # FR forward and FR losses
        theta = alignment(landmarks)
        grid = torch.nn.functional.affine_grid(theta, torch.Size((batch_size, 3, 112, 96)))
        real_aligned = torch.nn.functional.grid_sample(img.cuda(), grid)
        real_aligned = real_aligned[:, [2, 1, 0], ...]

        fake_aligned = torch.nn.functional.grid_sample(fake, grid)
        fake_aligned = fake_aligned[:, [2, 1, 0], ...]

        rand_fake_aligned = torch.nn.functional.grid_sample(rand_fake, grid)
        rand_fake_aligned = rand_fake_aligned[:, [2, 1, 0, ], ...]
        # (B, 3, h, w)

        if args.lambda_G_recon > 0:
            recon_aligned = torch.nn.functional.grid_sample(recon, grid)
            recon_aligned = recon_aligned[:, [2, 1, 0], ...]
            rand_recon_aligned = torch.nn.functional.grid_sample(rand_recon, grid)
            rand_recon_aligned = rand_recon_aligned[:, [2, 1, 0], ...]

        current_losses = {}
        # -------------------- D PART --------------------
        if optimizer_dict['D'] is not None:
            set_requires_grad(model_dict['G_nets'], False)
            set_requires_grad(model_dict['D_nets'], True)
            optimizer_dict['D'].zero_grad()

            id_real = model_dict['FR'](real_aligned)[0]
            loss_D_FR_real = criterion_dict['FR'](id_real, label.to(args.device))

            cnt_FR_fake = 0.
            loss_D_FR_fake_total = 0
            if args.train_M:
                id_fake = model_dict['FR'](fake_aligned.detach())[0]
                id_rand_fake = model_dict['FR'](rand_fake_aligned.detach())[0]

                loss_D_FR_fake = criterion_dict['FR'](id_fake, label.to(args.device))
                loss_D_FR_rand_fake = criterion_dict['FR'](id_rand_fake, label.to(args.device))

                loss_D_FR_fake_total += loss_D_FR_fake + loss_D_FR_rand_fake
                cnt_FR_fake += 2.
                current_losses.update({'D_FR_fake': loss_D_FR_fake.item(),
                                       'D_FR_rand': loss_D_FR_rand_fake.item(),
                                       # 'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                                       })

            if args.recon_FR:
                # TODO: rand_fake_recon FR loss?
                id_recon = model_dict['FR'](recon_aligned.detach())[0]
                loss_D_FR_recon = -criterion_dict['FR'](id_recon, label.to(args.device))
                if args.lambda_FR_WR:
                    id_rand_recon = model_dict['FR'](rand_recon_aligned.detach())[0]
                    loss_D_FR_rand_recon = criterion_dict['FR'](id_rand_recon, label.to(args.device))
                    current_losses.update({'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                                           })
                else:
                    loss_D_FR_rand_recon = 0.

                loss_D_FR_fake_total += loss_D_FR_recon + args.lambda_FR_WR * loss_D_FR_rand_recon
                cnt_FR_fake += 1. + args.lambda_FR_WR
                current_losses.update({'D_FR_recon': loss_D_FR_recon.item(),
                                       # 'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                                       })


            loss_D_FR_fake_avg = loss_D_FR_fake_total / float(cnt_FR_fake)

            loss_D = args.lambda_FR * (loss_D_FR_real + loss_D_FR_fake_avg) * 0.5
            current_losses.update({'D_FR_real': loss_D_FR_real.item(),
                                   'D_FR_fake': loss_D_FR_fake_avg.item()
                              # 'D_FR_fake': loss_D_FR_fake.item(),
                              # 'D_FR_rand': loss_D_FR_rand_fake.item(),
                              # 'D_FR_rand_recon': loss_D_FR_rand_recon.item()
                              })

            # GAN loss
            if args.lambda_GAN > 0:
                # real
                if args.recon_pair_GAN:
                    assert args.single_GAN_recon_only
                    real_input = torch.cat((img.cuda(), recon.detach()), dim=1)
                else:
                    real_input = img

                pred_D_real = model_dict['D'](real_input)
                loss_D_real = criterion_dict['GAN'](pred_D_real, True)

                # fake
                loss_D_fake_total = 0.
                loss_D_fake_total_weights = 0.

                # recon
                if args.lambda_GAN_recon:
                    if args.recon_pair_GAN:
                        recon_input_to_pool = torch.cat((recon.detach().cpu(), img), dim=1)
                    else:
                        recon_input_to_pool = recon.detach().cpu()

                    pred_D_recon = model_dict['D'](recon_pool.query(recon_input_to_pool))
                    loss_D_recon = criterion_dict['GAN'](pred_D_recon, False)

                    loss_D_fake_total += args.lambda_GAN_recon * loss_D_recon
                    loss_D_fake_total_weights += args.lambda_GAN_recon
                    current_losses['D_recon'] = loss_D_recon.item()

                if not args.single_GAN_recon_only:
                    assert args.lambda_pair_GAN == 0
                    if args.train_M:
                        all_M = torch.cat((fake.detach().cpu(),
                                           rand_fake.detach().cpu(),
                                           ), 0)
                        pred_D_M = model_dict['D'](fake_pool.query(all_M))
                        loss_D_M = criterion_dict['GAN'](pred_D_M, False)

                        loss_D_fake_total += args.lambda_GAN_M * loss_D_M
                        loss_D_fake_total_weights += args.lambda_GAN_M
                        current_losses['D_M'] = loss_D_M.item()

                    if args.lambda_GAN_WR:
                        pred_D_WR = model_dict['D'](WR_pool.query(rand_recon.detach().cpu()))
                        loss_D_WR = criterion_dict['GAN'](pred_D_WR, False)

                        loss_D_fake_total += args.lambda_GAN_WR * loss_D_WR
                        loss_D_fake_total_weights += args.lambda_GAN_WR
                        current_losses['D_WR'] = loss_D_WR.item()


                loss_D_fake = loss_D_fake_total / loss_D_fake_total_weights
                loss_D += args.lambda_GAN * (loss_D_fake + loss_D_real) * 0.5

                current_losses.update({
                    'D_real': loss_D_real.item(),
                    'D_fake': loss_D_fake.item()
                })


            if args.lambda_pair_GAN > 0:
                loss_pair_fake_total = 0
                loss_pair_real_total = 0
                loss_pair_cnt = 0.
                if args.train_M:
                    pred_pair_real1 = model_dict['pair_D'](torch.cat((img.cuda(), fake.detach()), 1))
                    pred_pair_real2 = model_dict['pair_D'](torch.cat((img.cuda(), rand_fake.detach()), 1))

                    all_fake_pair = torch.cat((torch.cat((fake.detach().cpu(), img), 1),
                                               torch.cat((rand_fake.detach().cpu(), img), 1),
                                               ), 0)
                    pred_pair_fake = model_dict['pair_D'](fake_pair_pool.query(all_fake_pair))

                    loss_pair_M_real = (criterion_dict['GAN'](pred_pair_real1, True) + criterion_dict['GAN'](pred_pair_real2, True)) / 2.
                    loss_pair_M_fake = criterion_dict['GAN'](pred_pair_fake, False)

                    loss_pair_real_total += loss_pair_M_real
                    loss_pair_fake_total += loss_pair_M_fake
                    loss_pair_cnt += 1

                pred_pair_WR_real = model_dict['pair_D'](torch.cat((img.cuda(), rand_recon.detach()), 1))
                pred_pair_WR_fake = model_dict['pair_D'](WR_pool.query(torch.cat((rand_recon.detach().cpu(), img), 1)))

                loss_pair_WR_real = criterion_dict['GAN'](pred_pair_WR_real, True)
                loss_pair_WR_fake = criterion_dict['GAN'](pred_pair_WR_fake, False)

                loss_pair_real_total += args.multiple_pair_WR_GAN * loss_pair_WR_real
                loss_pair_fake_total += args.multiple_pair_WR_GAN * loss_pair_WR_fake
                loss_pair_cnt += args.multiple_pair_WR_GAN

                loss_pair_D_real = loss_pair_real_total / loss_pair_cnt  # (loss_pair_M_real + args.multiple_pair_WR_GAN * loss_pair_WR_real) / (1. + args.multiple_pair_WR_GAN)
                loss_pair_D_fake = loss_pair_fake_total / loss_pair_cnt #(loss_pair_M_fake + args.multiple_pair_WR_GAN * loss_pair_WR_fake) / (1. + args.multiple_pair_WR_GAN)

                current_losses.update({
                    'pair_D_fake': loss_pair_D_fake.item(),
                    'pair_D_real': loss_pair_D_real.item()
                })
                loss_D += args.lambda_pair_GAN * (loss_pair_D_fake + loss_pair_D_real) * 0.5

            current_losses['D'] = loss_D.item()
            # D backward and optimizer steps
            loss_D.backward()

            if args.gan_mode == 'wgangp':
                real_to_wgangp = torch.cat((img, img), 0).to(args.device)
                if np.random.rand() > 0.5:
                    fake_selected = fake.detach()
                else:
                    fake_selected = rand_fake.detach()
                fake_to_wgangp = torch.cat((fake_selected, rand_recon.detach()), 0)
                loss_gp, gradients = models.cal_gradient_penalty(model_dict['D'], real_to_wgangp, fake_to_wgangp, args.device)
                # print('gradeints abs/l2 mean:', gradients[0], gradients[1])
                loss_gp *= args.lambda_GAN
                # print('loss_gp', loss_gp.item())
                loss_gp.backward()

            optimizer_dict['D'].step()

        # -------------------- G PART --------------------
        # init
        set_requires_grad(model_dict['D_nets'], False)
        set_requires_grad(model_dict['G_nets'], True)
        optimizer_dict['G'].zero_grad()

        loss_G = 0
        # GAN loss
        if args.lambda_GAN > 0:
            loss_G_GAN_total = 0.
            loss_G_GAN_total_weights = 0.

            # recon
            if args.lambda_GAN_recon:
                if args.recon_pair_GAN:
                    recon_input_G = torch.cat((recon, img.cuda()), dim=1)
                else:
                    recon_input_G = recon
                pred_G_recon = model_dict['D'](recon_input_G)
                loss_G_recon = criterion_dict['GAN'](pred_G_recon, True)

                loss_G_GAN_total += args.lambda_GAN_recon * loss_G_recon
                loss_G_GAN_total_weights += args.lambda_GAN_recon
                current_losses['G_recon'] = loss_G_recon.item()

            if not args.single_GAN_recon_only:
                if args.train_M:
                    pred_G_fake = model_dict['D'](fake)
                    pred_G_rand_fake = model_dict['D'](rand_fake)

                    loss_G_fake = criterion_dict['GAN'](pred_G_fake, True)
                    loss_G_rand_fake = criterion_dict['GAN'](pred_G_rand_fake, True)

                    loss_G_GAN_total += args.lambda_GAN_M * 0.5 * (loss_G_fake + loss_G_rand_fake)
                    loss_G_GAN_total_weights += args.lambda_GAN_M

                    current_losses['G_M'] = 0.5 * (loss_G_fake.item() + loss_G_rand_fake.item())

                pred_G_WR = model_dict['D'](rand_recon)
                loss_G_WR = criterion_dict['GAN'](pred_G_WR, True)
                current_losses['G_WR'] = loss_G_WR.item()

                loss_G_GAN_total += args.lambda_GAN_WR * loss_G_WR
                loss_G_GAN_total_weights += args.lambda_GAN_WR

            loss_G_GAN = loss_G_GAN_total / loss_G_GAN_total_weights
            loss_G += args.lambda_GAN * loss_G_GAN

            current_losses.update({'G_GAN': loss_G_GAN.item(),
                                   })


        if args.lambda_pair_GAN > 0:
            loss_pair_G_total = 0
            cnt_pair_G = 0.

            if args.train_M:
                pred_pair_fake1_G = model_dict['pair_D'](torch.cat((fake, img.cuda()), 1))
                pred_pair_fake2_G = model_dict['pair_D'](torch.cat((rand_fake, img.cuda()), 1))

                loss_pair_M_G = (criterion_dict['GAN'](pred_pair_fake1_G, True)
                               + criterion_dict['GAN'](pred_pair_fake2_G, True)) / 2.

                loss_pair_G_total += loss_pair_M_G
                cnt_pair_G += 1.

            pred_pair_fake3_G = model_dict['pair_D'](torch.cat((rand_recon, img.cuda()), 1))
            loss_pair_WR_G = criterion_dict['GAN'](pred_pair_fake3_G, True)

            loss_pair_G_total += args.multiple_pair_WR_GAN * loss_pair_WR_G
            cnt_pair_G += args.multiple_pair_WR_GAN

            loss_pair_G_avg = loss_pair_G_total / cnt_pair_G

            loss_G += args.lambda_pair_GAN * loss_pair_G_avg
            current_losses['pair_G'] = loss_pair_G_avg.item()

        # infoGAN loss
        def infoGAN_input(img1, img2):
            if args.use_minus_Q:
                return img2 - img1
            else:
                return torch.cat((img1, img2), 1)

        if args.lambda_dis > 0:
            infogan_acc = 0
            infogan_inv_acc = 0
            infogan_rand_acc = 0
            infogan_recon_rand_acc = 0

            dis_logits = model_dict['Q'](infoGAN_input(img.cuda(), fake))
            loss_G_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = dis_logits[dis_idx].max(dim=1)[1]
                b = dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_acc += acc.item()
                loss_G_dis += criterion_dict['DIS'](dis_logits[dis_idx], dis_target[:, dis_idx])
            infogan_acc = infogan_acc / float(args.passwd_length // 4)

            inv_dis_logits = model_dict['Q'](infoGAN_input(fake, recon))
            loss_G_inv_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = inv_dis_logits[dis_idx].max(dim=1)[1]
                b = inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_inv_acc += acc.item()
                loss_G_inv_dis += criterion_dict['DIS'](inv_dis_logits[dis_idx], inv_dis_target[:, dis_idx])
            infogan_inv_acc = infogan_inv_acc / float(args.passwd_length // 4)

            rand_dis_logits = model_dict['Q'](infoGAN_input(img.cuda(), rand_fake))
            loss_G_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = rand_dis_logits[dis_idx].max(dim=1)[1]
                b = rand_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_rand_acc += acc.item()
                loss_G_rand_dis += criterion_dict['DIS'](rand_dis_logits[dis_idx], rand_dis_target[:, dis_idx])
            infogan_rand_acc = infogan_rand_acc / float(args.passwd_length // 4)

            recon_rand_dis_logits = model_dict['Q'](infoGAN_input(fake, rand_recon))
            loss_G_recon_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = recon_rand_dis_logits[dis_idx].max(dim=1)[1]
                b = another_rand_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_recon_rand_acc += acc.item()
                loss_G_recon_rand_dis += criterion_dict['DIS'](recon_rand_dis_logits[dis_idx], another_rand_dis_target[:, dis_idx])
            infogan_recon_rand_acc = infogan_recon_rand_acc / float(args.passwd_length // 4)

            # current_losses.update({'G_dis': loss_G_dis.item(),
            #                        'G_inv_dis': loss_G_inv_dis.item(),
            #                        'G_dis_acc': infogan_acc,
            #                        'G_inv_dis_acc': infogan_inv_acc,
            #                        'G_rand_dis': loss_G_rand_dis.item(),
            #                        'G_recon_rand_dis': loss_G_recon_rand_dis.item(),
            #                        'G_rand_dis_acc': infogan_rand_acc,
            #                        'G_recon_rand_dis_acc': infogan_recon_rand_acc
            #                        })
            # Aggregate the four InfoGAN-style discriminator terms (fake, inverse,
            # random-z, and reconstructed-random samples) into one generator loss,
            # and average their accuracies for logging.
            loss_dis = (loss_G_dis + loss_G_inv_dis + loss_G_rand_dis + loss_G_recon_rand_dis)
            dis_acc = (infogan_acc + infogan_inv_acc + infogan_rand_acc + infogan_recon_rand_acc) / 4.
            loss_G += args.lambda_dis * loss_dis
            current_losses.update({
                'dis': loss_dis.item(),
                'dis_acc': dis_acc
            })

        # FR loss, netFR must not be fixed
        # Adversarial face-recognition (FR) terms: the FR criterion is negated so
        # the generator is rewarded for making FR mis-classify the generated faces.
        loss_G_FR_total = 0
        cnt_G_FR = 0.

        if args.train_M:
            # FR on the aligned fake image: negative CE pushes the generator to
            # hide the true identity `label`.
            id_fake_G, fake_feat = model_dict['FR'](fake_aligned)
            loss_G_FR = -criterion_dict['FR'](id_fake_G, label.to(args.device))
            # current_losses['G_FR'] = loss_G_FR.item()

            # Same adversarial term for the random-z fake.
            id_rand_fake_G, rand_fake_feat = model_dict['FR'](rand_fake_aligned)
            loss_G_FR_rand = -criterion_dict['FR'](id_rand_fake_G, label.to(args.device))
            # current_losses['G_FR_rand'] = loss_G_FR_rand.item()

            loss_G_FR_total += loss_G_FR + loss_G_FR_rand
            cnt_G_FR += 2

        if args.feature_loss == 'cos':
            # Target of -1 for a cosine-style feature criterion: drive the two
            # feature vectors apart (dissimilarity objective).
            FR_cos_sim_target = torch.empty(size=(batch_size, 1), dtype=torch.float32, device=args.device)
            FR_cos_sim_target.fill_(-1.)

        if args.lambda_Feat:
            # Push FR features of `fake` and `rand_fake` apart so different z
            # codes yield different identities.
            # NOTE(review): fake_feat / rand_fake_feat are only assigned inside
            # the `if args.train_M` branch above — this looks like a NameError
            # if lambda_Feat is set without train_M; confirm the arg combinations.
            if args.feature_loss == 'cos':
                loss_G_feat = criterion_dict['Feat'](fake_feat, rand_fake_feat, target=FR_cos_sim_target)
            else:
                # Non-cos path: negated similarity criterion, same dissimilarity intent.
                loss_G_feat = -criterion_dict['Feat'](fake_feat, rand_fake_feat)
            current_losses['G_feat'] = loss_G_feat.item()
            loss_G += args.lambda_Feat * loss_G_feat


        if args.lambda_G_recon:
            # FR features of the reconstruction (and, optionally, the
            # wrong-reconstruction produced from a random z).
            id_recon_G, recon_feat = model_dict['FR'](recon_aligned)
            if args.lambda_FR_WR:
                id_rand_recon_G, rand_recon_feat = model_dict['FR'](rand_recon_aligned)

            if args.lambda_recon_Feat:
                # Features of correct vs. wrong reconstruction should differ.
                # NOTE(review): rand_recon_feat is only bound when lambda_FR_WR
                # is truthy — verify lambda_recon_Feat implies lambda_FR_WR.
                if args.feature_loss == 'cos':
                    loss_G_recon_feat = criterion_dict['Feat'](recon_feat, rand_recon_feat, target=FR_cos_sim_target)
                else:
                    loss_G_recon_feat = -criterion_dict['Feat'](recon_feat, rand_recon_feat)
                current_losses['G_recon_feat'] = loss_G_recon_feat.item()
                loss_G += args.lambda_recon_Feat * loss_G_recon_feat

            if args.lambda_false_recon_diff:
                # Fake features should also differ from the wrong-reconstruction
                # features (same rand_recon_feat availability caveat as above).
                if args.feature_loss == 'cos':
                    loss_G_false_recon_feat =criterion_dict['Feat'](fake_feat, rand_recon_feat, target=FR_cos_sim_target)
                else:
                    loss_G_false_recon_feat =-criterion_dict['Feat'](fake_feat, rand_recon_feat)
                current_losses['G_false_recon_feat'] = loss_G_false_recon_feat.item()
                loss_G += args.lambda_false_recon_diff * loss_G_false_recon_feat

            if args.recon_FR:
                # Positive CE here: the reconstruction SHOULD be recognized as
                # the true identity, while the wrong reconstruction (negated CE,
                # weighted by lambda_FR_WR) should not.
                loss_G_FR_recon = criterion_dict['FR'](id_recon_G, label.to(args.device))
                # current_losses['G_FR_recon'] = loss_G_FR_recon.item()
                if args.lambda_FR_WR:
                    loss_G_FR_rand_recon = -criterion_dict['FR'](id_rand_recon_G, label.to(args.device))
                else:
                    loss_G_FR_rand_recon = 0.
                # current_losses['G_FR_rand_recon'] = loss_G_FR_rand_recon.item()
                loss_G_FR_total += loss_G_FR_recon + args.lambda_FR_WR * loss_G_FR_rand_recon
                cnt_G_FR += 1. + args.lambda_FR_WR

        # NOTE(review): if neither train_M nor (lambda_G_recon and recon_FR) is
        # enabled, cnt_G_FR stays 0. and this divides by zero (and the int 0
        # total has no .item()); confirm those arg combinations are excluded.
        loss_G_FR_avg = loss_G_FR_total / cnt_G_FR

        loss_G += args.lambda_FR * loss_G_FR_avg
        current_losses['G_FR'] = loss_G_FR_avg.item()


        # Pixel-space L1 losses against the real image. NOTE(review): these use
        # img.cuda() while the FR terms use .to(args.device) — presumably
        # equivalent here, but inconsistent; verify on non-default devices.
        # loss_L1 = 0
        # cnt_loss_L1 = 0
        if args.lambda_L1 > 0:
            # Fake should stay close to the input image.
            loss_G_L1 = criterion_dict['L1'](fake, img.cuda())
            current_losses['L1'] = loss_G_L1.item()
            # loss_L1 += loss_G_L1.item()
            # cnt_loss_L1 += 1
            loss_G += args.lambda_L1 * loss_G_L1

        if args.lambda_rand_L1 > 0:
            # Random-z fake is also anchored to the input image.
            loss_G_rand_L1 = criterion_dict['L1'](rand_fake, img.cuda())
            current_losses['rand_L1'] = loss_G_rand_L1.item()
            # loss_L1 += loss_G_rand_L1.item()
            # cnt_loss_L1 += 1
            loss_G += args.lambda_rand_L1 * loss_G_rand_L1

        if args.lambda_rand_recon_L1 > 0:
            # Wrong reconstruction compared to the input image.
            loss_G_rand_recon_L1 = criterion_dict['L1'](rand_recon, img.cuda())
            current_losses['wrong_recon_L1'] = loss_G_rand_recon_L1.item()
            # loss_L1 += loss_G_rand_recon_L1.item()
            # cnt_loss_L1 += 1
            loss_G += args.lambda_rand_recon_L1 * loss_G_rand_recon_L1

        # current_losses['L1'] = loss_L1 / float(cnt_loss_L1)

        if args.lambda_G_recon > 0:
            # Cycle-style reconstruction loss: recon should match the original.
            loss_G_recon = criterion_dict['L1'](recon, img.cuda())
            loss_G += args.lambda_G_recon * loss_G_recon
            current_losses['recon'] = loss_G_recon.item()

        if args.lambda_G_rand_recon > 0:
            # Invert the random code (either z -> -z or z -> 1-z depending on the
            # z parameterization flag) and require G(rand_fake, inv_z) to undo
            # the random-z generation back to the original image.
            if args.use_minus_one:
                inv_rand_z = rand_z * -1
            else:
                inv_rand_z = 1.0 - rand_z
            rand_fake_recon = model_dict['G'](rand_fake, inv_rand_z)
            loss_G_rand_recon = criterion_dict['L1'](rand_fake_recon, img.cuda())
            loss_G += args.lambda_G_rand_recon * loss_G_rand_recon
            current_losses['another_recon'] = loss_G_rand_recon.item()

        current_losses['G'] = loss_G.item()

        # G backward and optimizer steps
        loss_G.backward()
        optimizer_dict['G'].step()

        # -------------------- LOGGING PART --------------------
        # Periodic console/visdom loss reporting and optional per-parameter
        # gradient statistics for debugging.
        if i % args.print_loss_freq == 0:
            t = (time.time() - iter_start_time) / batch_size
            visualizer.print_current_losses(epoch, i, current_losses, t, t_data)
            if args.display_id > 0 and i % args.plot_loss_freq == 0:
                visualizer.plot_current_losses(epoch, float(i) / len(train_loader), args, current_losses)
            if args.print_gradient:
                for net_name, net in model_dict.items():
                    # if net_name != 'Q':
                    #     continue
                    # Skip composite entries (lists of sub-networks).
                    if isinstance(net, list):
                        continue
                    print(('================ NET %s ================' % net_name))
                    for name, param in net.named_parameters():
                        print_param_info(name, param, print_std=True)

        # Periodic image dumps to visdom/HTML; each tensor is detached so the
        # visualizer never holds onto the autograd graph.
        if i % args.visdom_visual_freq == 0:
            save_result = i % args.update_html_freq == 0

            current_visuals = OrderedDict()
            current_visuals['real'] = img.detach()
            current_visuals['fake'] = fake.detach()
            current_visuals['rand_fake'] = rand_fake.detach()
            if args.lambda_G_recon:
                current_visuals['recon'] = recon.detach()
                current_visuals['rand_recon'] = rand_recon.detach()
            if args.lambda_G_rand_recon > 0:
                current_visuals['rand_fake_recon'] = rand_fake_recon.detach()
            current_visuals['real_aligned'] = real_aligned.detach()
            current_visuals['fake_aligned'] = fake_aligned.detach()
            current_visuals['rand_fake_aligned'] = rand_fake_aligned.detach()
            if args.lambda_G_recon:
                current_visuals['recon_aligned'] = recon_aligned.detach()
                current_visuals['rand_recon_aligned'] = rand_recon_aligned.detach()

            # Guard against a hung visdom server: abort the display call after
            # 60s and permanently disable remote display for this run.
            try:
                with time_limit(60):
                    visualizer.display_current_results(current_visuals, epoch, save_result, args)
            except TimeoutException:
                visualizer.logger.log('TIME OUT visualizer.display_current_results epoch:{} iter:{}. Change display_id to -1'.format(epoch, i))
                args.display_id = -1

        # Mid-epoch checkpointing and visdom environment snapshots.
        if (i + 1) % args.save_iter_freq == 0:
            save_model(epoch, model_dict, optimizer_dict, args, iter=i)
            if args.display_id > 0:
                visualizer.vis.save([args.name])
                visualizer.overview_vis.save(['overview'])

        # Periodic evaluation on the test set with fixed z codes.
        if (i + 1) % args.html_iter_freq == 0:
            test(test_loader, model_dict, criterion_dict, visualizer, epoch, args, fixed_z, fixed_rand_z, i)

        # Reset the data-loading timer used for the next iteration's t_data.
        if (i + 1) % args.print_loss_freq == 0:
            iter_data_time = time.time()