Ejemplo n.º 1
0
def fit_validate(exp_params, k, data_path, write_path, others=None, custom_tag=''):
    """Fit model and compute metrics on train and validation set. Intended for hyperparameter search.

    Only logs final metrics and scatter plot of final embedding.

    The ``'train'`` split of ``dataset_name`` is loaded and then split again into a
    train/validation pair with ``FOLD_SEEDS[k]``. The same fold seed is used as the
    model's ``random_state``. An embedding prober is fitted on the train subset only.
    The fitted prober then scores both subsets, and the scores are logged under Comet's
    train and validate contexts respectively.

    Args:
        exp_params(dict): Parameter dict. Should at least have keys model_name, dataset_name & random_state. Other
        keys are assumed to be model parameters.
        k(int): Fold identifier, used as an index into ``FOLD_SEEDS``.
        data_path(str): Data directory.
        write_path(str): Where to write temp files.
        others(dict): Other things to log to Comet experiment.
        custom_tag(str): Custom tag for comet experiment. ``'_validate'`` is appended before tagging.

    """
    # Comet experiment
    exp = Experiment(parse_args=False)
    # NOTE(review): presumably disables Comet's automatic framework instrumentation — confirm.
    exp.disable_mp()
    custom_tag += '_validate'
    exp.add_tag(custom_tag)
    exp.log_parameters(exp_params)

    if others is not None:
        exp.log_others(others)

    # Parse experiment parameters: (model_name, dataset_name, random_state, model_params)
    model_name, dataset_name, random_state, model_params = parse_params(exp_params)

    # Fetch the training split, then carve a validation subset out of it using the fold seed.
    data_train = getattr(grae.data, dataset_name)(split='train', random_state=random_state, data_path=data_path)
    data_train, data_val = data_train.validation_split(random_state=FOLD_SEEDS[k])

    # Model: class looked up by name in grae.models; seeded with the fold seed.
    m = getattr(grae.models, model_name)(random_state=FOLD_SEEDS[k], **model_params)
    m.write_path = write_path
    m.data_val = data_val

    with exp.train():
        m.fit(data_train)

        # Log plot. The experiment is attached only after fit(); presumably so that
        # only the final plot is sent to Comet — confirm against the model's plot().
        m.comet_exp = exp
        m.plot(data_train, data_val, title=f'{model_name} : {dataset_name}')

        # Probe embedding (prober is fitted on the train subset only)
        prober = EmbeddingProber()
        prober.fit(model=m, dataset=data_train, mse_only=True)
        train_z, train_metrics = prober.score(data_train, is_train=True)

        # Log train metrics
        exp.log_metrics(train_metrics)

    with exp.validate():
        # Score the held-out subset with the prober fitted above.
        val_z, val_metrics = prober.score(data_val)

        # Log validation metrics
        exp.log_metrics(val_metrics)

    # Log marker to mark successful experiment (only reached if nothing above raised).
    # NOTE(review): exp.end() is not called here; confirm Comet flushes on process exit.
    exp.log_other('success', 1)
Ejemplo n.º 2
0
def log_metrics(metrics: dict, comet_logger: Experiment, epoch: int,
                context_val: bool):
    """Log a metrics dict for one epoch under Comet's validate or train context.

    Args:
        metrics: Mapping of metric name to value.
        comet_logger: Comet experiment that receives the metrics.
        epoch: Epoch number attached to the logged values.
        context_val: If True, log inside ``comet_logger.validate()``;
            otherwise inside ``comet_logger.train()``.
    """
    scope = comet_logger.validate if context_val else comet_logger.train
    with scope():
        comet_logger.log_metrics(metrics, epoch=epoch)
Ejemplo n.º 3
0
def log_simclr_images(img1: Tensor, img2: Tensor, context_val: bool,
                      comet_logger: Experiment):
    """Plot one image pair from two SimCLR views to Comet.

    Only index 0 along the first dimension of each tensor is plotted, after
    moving it to the CPU.

    Args:
        img1: First view tensor.
        img2: Second view tensor.
        context_val: If True, plot inside ``comet_logger.validate()``;
            otherwise inside ``comet_logger.train()``.
        comet_logger: Comet experiment passed on to ``plot_simclr_images``.
    """
    enter_scope = comet_logger.validate if context_val else comet_logger.train
    with enter_scope():
        first_view = img1.data[0].cpu()
        second_view = img2.data[0].cpu()
        plot_simclr_images(first_view, second_view, comet_logger)
Ejemplo n.º 4
0
def log_hybrid2_images(
    img1: Tensor,
    img2: Tensor,
    params: Dict[str, Tensor],
    context_val: bool,
    comet_logger: Experiment,
):
    """Plot one image pair plus its per-image parameters to Comet.

    For the images and for every tensor in ``params``, only index 0 along the
    first dimension is used, moved to the CPU.

    Args:
        img1: First image tensor.
        img2: Second image tensor.
        params: Mapping of name to tensor; each value is reduced to its element 0.
        context_val: If True, plot inside ``comet_logger.validate()``;
            otherwise inside ``comet_logger.train()``.
        comet_logger: Comet experiment passed on to ``plot_hybrid2_images``.
    """
    first_params = {}
    for name, tensor in params.items():
        first_params[name] = tensor.data[0].cpu()

    enter_scope = comet_logger.validate if context_val else comet_logger.train
    with enter_scope():
        plot_hybrid2_images(img1.data[0].cpu(), img2.data[0].cpu(),
                            first_params, comet_logger)
Ejemplo n.º 5
0
class Logger(object):
    """Log losses and images to TensorBoard and, optionally, to Comet.

    TensorBoard event files are written to ``<hp.logdir>/<dataset_name>-<model_name>``.
    A Comet experiment is created only when ``hp.comet_ml_api_key`` is not None.
    The key is taken from the hyper-parameter module and is never hard-coded.
    """

    def __init__(self, dataset_name, model_name):
        self.model_name = model_name
        self.project_name = "%s-%s" % (dataset_name, self.model_name)
        self.logdir = os.path.join(hp.logdir, self.project_name)
        self.writer = SummaryWriter(log_dir=self.logdir)

        # Comet logging is opt-in via hp.comet_ml_api_key.
        self.experiment = None
        if hp.comet_ml_api_key is not None:
            self.experiment = Experiment(api_key=hp.comet_ml_api_key,
                                         project_name=self.project_name,
                                         log_code=False)
            # Record every attribute of ``hp`` that is not a dunder name.
            # log_parameters replaces the deprecated log_multiple_params.
            self.experiment.log_parameters(
                {name: getattr(hp, name) for name in dir(hp)
                 if not name.startswith('__')})

    def log_step(self, phase, step, loss_dict, image_dict):
        """Log per-step training losses and images.

        Only the ``'train'`` phase is logged; other phases are ignored.
        Losses are logged every 50 steps, to TensorBoard and, if enabled, to
        Comet under its train context. Images go to TensorBoard every 1000 steps.
        """
        if phase != 'train':
            return

        if step % 50 == 0:
            if self.experiment is not None:
                with self.experiment.train():
                    self.experiment.log_metrics(loss_dict, step=step)
            for key in sorted(loss_dict):
                self.writer.add_scalar('%s-step/%s' % (phase, key),
                                       loss_dict[key], step)

        if step % 1000 == 0:
            for key in sorted(image_dict):
                self.writer.add_image('%s/%s' % (self.model_name, key),
                                      image_dict[key], step)

    def log_epoch(self, phase, step, loss_dict):
        """Log epoch-level losses to TensorBoard; ``'valid'`` losses also go to Comet."""
        for key in sorted(loss_dict):
            self.writer.add_scalar('%s/%s' % (phase, key), loss_dict[key],
                                   step)

        if phase == 'valid' and self.experiment is not None:
            with self.experiment.validate():
                self.experiment.log_metrics(loss_dict, step=step)
Ejemplo n.º 6
0
def log_pairwise_images(
    img1: Tensor,
    img2: Tensor,
    gt_pred: Dict[str, Tensor],
    context_val: bool,
    comet_logger: Experiment,
):
    """Plot one image pair together with paired values to Comet.

    Each value of ``gt_pred`` is indexed at positions 0 and 1 (presumably
    ground truth and prediction, judging by the name; confirm with callers).
    Each of those is reduced to its element 0 and converted to a NumPy array.

    Args:
        img1: First image tensor; index 0 along the first dimension is plotted.
        img2: Second image tensor; index 0 along the first dimension is plotted.
        gt_pred: Mapping of name to an indexable pair of tensors.
        context_val: If True, plot inside ``comet_logger.validate()``;
            otherwise inside ``comet_logger.train()``.
        comet_logger: Comet experiment passed on to ``plot_pairwise_images``.
    """
    first_pairs = {}
    for name, pair in gt_pred.items():
        first = pair[0].data[0].cpu().numpy()
        second = pair[1].data[0].cpu().numpy()
        first_pairs[name] = [first, second]

    enter_scope = comet_logger.validate if context_val else comet_logger.train
    with enter_scope():
        plot_pairwise_images(img1.data[0].cpu(), img2.data[0].cpu(),
                             first_pairs, comet_logger)
Ejemplo n.º 7
0
def log_image(
    prediction: Tensor,
    y: Tensor,
    x: Tensor,
    gpu: bool,
    context_val: bool,
    comet_logger: Experiment,
):
    """Plot the first prediction against its label and input image to Comet.

    Args:
        prediction: Model output; element 0 is converted to a NumPy array.
        y: Targets; element 0 is converted to a NumPy array.
        x: Inputs; element 0 is moved to the CPU and plotted.
        gpu: If True, tensors are moved to the CPU before conversion.
        context_val: If True, plot inside ``comet_logger.validate()``;
            otherwise inside ``comet_logger.train()``.
        comet_logger: Comet experiment passed on to ``plot_truth_vs_prediction``.
    """
    if not gpu:
        pred_label = prediction[0].detach().numpy()
        true_label = y[0].detach().numpy()
    else:
        pred_label = prediction.data[0].cpu().numpy()
        true_label = y.data[0].cpu().detach().numpy()

    enter_scope = comet_logger.validate if context_val else comet_logger.train
    with enter_scope():
        plot_truth_vs_prediction(pred_label, true_label, x.data[0].cpu(),
                                 comet_logger)
Ejemplo n.º 8
0
class Experiment:
    """
        A helper class to facilitate the training and validation procedure of the GoTurnRemix model

        Parameters
        ----------
        learning_rate: float
            Learning rate to train the model. The optimizer is SGD and the loss is L1 Loss
        image_size: int
            The size of the input image. This has to be fixed before the data is created
        data_path: Path
            Path to the data folder. If the folder name includes "pickle", then the data saved as pickles are loaded
        augment: bool
            Perform augmentation on the images before training
        logs_path: Path
            Path to save the per-epoch validation predictions and the test predictions at the end of training
        models_path: Path
            Path to save the model state at the end of each epoch
        save_name: str
            Name of the folder in which the logs and models are saved. If not provided, the current datetime is used
        comet_api: str
            Comet API key. If given, parameters and metrics are logged to Comet
    """
    def __init__(self,
                 learning_rate: float,
                 image_size: int,
                 data_path: Path,
                 augment: bool = True,
                 logs_path: Path = None,
                 models_path: Path = None,
                 save_name: str = None,
                 comet_api: str = None):
        self.image_size = image_size
        self.logs_path = logs_path
        self.models_path = models_path
        self.model = GoTurnRemix()
        self.model.cuda()
        self.criterion = torch.nn.L1Loss()
        # Only parameters that require gradients are optimised.
        self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                self.model.parameters()),
                                         lr=learning_rate)
        self.model_name = str(datetime.datetime.now()).split('.')[0].replace(
            ':', '-').replace(' ', '-')
        self.model_name = save_name if save_name else self.model_name
        self.augment = augment
        self.data = Data(data_path,
                         target_size=self.image_size,
                         transforms=augment)
        self.comet = None
        if comet_api:
            self.comet = Comet(api_key=comet_api)
            self.comet.log_parameter('learning_rate', learning_rate)
            self.comet.log_parameter('image_size', image_size)
            self.comet.log_parameter('augment', augment)

    def _comet_context(self, name):
        """
        Return Comet's ``name`` context manager ('train', 'validate' or 'test'), or a no-op one

        Comet's train()/validate()/test() only switch the logging context when used
        in a ``with`` statement; calling them bare has no effect.
        """
        import contextlib
        if self.comet is None:
            return contextlib.nullcontext()
        return getattr(self.comet, name)()

    def __train_step__(self, data):
        """
        Performs one step of the training procedure

        Parameters
        ----------
        data
            data obtained from @Data.__getitem__

        Returns
        -------
           Loss at the end of training step
        """
        previous_cropped, current_cropped, bbox, scale, crop = data
        # Inputs and targets are data, not parameters: they need no gradients.
        previous_cropped = torch.div(previous_cropped, 255).float().cuda()
        current_cropped = torch.div(current_cropped, 255).float().cuda()
        bbox = bbox.float().cuda()
        self.optimizer.zero_grad()
        preds = self.model(previous_cropped, current_cropped)

        del previous_cropped
        del current_cropped
        gc.collect()

        loss = self.criterion(preds, bbox)
        if self.comet:
            with self._comet_context('train'):
                self.comet.log_metric('loss', loss.item())
        loss.backward()
        self.optimizer.step()
        return loss

    def _track(self, desc, feed_predictions):
        """
        Run the model frame by frame over the last video and collect per-frame L1 losses

        Parameters
        ----------
        desc: str
            Label of the progress bar
        feed_predictions: bool
            If True, each predicted box becomes the previous annotation for the next frame (tracking).
            If False, the ground-truth box of the current frame is used instead.

        Returns
        -------
            List of per-frame losses and list of boxes (the first one is the ground truth of frame 0)
        """
        self.model.eval()
        predictions = []
        losses = []
        video_frames = self.data.video_frames[-1]
        video_annotations = self.data.video_annotations[-1]
        p_a = video_annotations[0]
        p_f = video_frames[0]
        predictions.append(p_a)

        # Inference only: skip building autograd graphs to save GPU memory.
        with torch.no_grad():
            for i in tqdm(range(1, len(video_annotations)), desc=desc):
                c_a = video_annotations[i]
                c_f = video_frames[i]
                p_c, c_c, bbox, scale, crop = self.data.make_crops(
                    p_f, c_f, p_a, c_a)
                p_c = torch.div(torch.from_numpy(p_c),
                                255).unsqueeze(0).float().cuda()
                c_c = torch.div(torch.from_numpy(c_c),
                                255).unsqueeze(0).float().cuda()
                bbox = torch.tensor(bbox, requires_grad=False).float().cuda()
                preds = self.model(p_c, c_c)

                del p_c
                del c_c
                gc.collect()

                loss = torch.nn.functional.l1_loss(preds, bbox)
                if self.comet:
                    self.comet.log_metric('val_loss', loss.item())
                losses.append(loss.item())
                preds = self.data.get_bbox(preds.cpu().detach().numpy()[0],
                                           self.image_size, scale, crop)
                predictions.append(preds)
                p_a = preds if feed_predictions else c_a
                p_f = c_f
        return losses, predictions

    def __test__(self):
        """
        Test tracking of the model: each prediction is fed back as the previous box

        Returns
        -------
            Test loss and test predictions
        """
        with self._comet_context('test'):
            return self._track('Testing', feed_predictions=True)

    def __validate__(self):
        """
        Performs validation on the model: each frame uses the ground-truth box of the previous frame

        Leaves the model in evaluation mode.

        Returns
        -------
            Validation loss and validation predictions
        """
        with self._comet_context('validate'):
            return self._track('Validating', feed_predictions=False)

    def train(self,
              epochs: int,
              batch_size: int,
              validate: bool = True,
              test: bool = True):
        """
        Trains the model for @epochs number of epochs

        Parameters
        ----------
        epochs: int
            Number of epochs to train the model
        batch_size: int
            The size of each batch when training the model
        validate: bool, default=True
            If True, validation occurs at the end of each epoch
            The results are saved in @logs_path and models are saved in @models_path
        test: bool, default=True
            If True, the model is tested for tracking at the end of the training procedure
            The results are saved in @logs_path

        Returns
        -------
            list: List containing the training loss at the end of each epoch
        """
        if self.comet:
            self.comet.log_parameter('epochs', epochs)
            self.comet.log_parameter('batch_size', batch_size)
        loss_per_epoch = []
        preds_per_epoch = []
        # Create a DataLoader to feed data to the model
        dataloader = torch.utils.data.DataLoader(dataset=self.data,
                                                 batch_size=batch_size,
                                                 shuffle=True)

        # Run for @epochs number of epochs
        for epoch in range(epochs):
            # __validate__ switches the model to eval mode, so restore training mode every epoch.
            self.model.train()
            if self.comet:
                self.comet.log_metric('epoch', epoch)
            running_loss = []
            for data in tqdm(dataloader,
                             total=len(dataloader),
                             desc='Epoch {}'.format(epoch)):
                loss = self.__train_step__(data)
                running_loss.append(loss.item())
            training_loss = sum(running_loss) / len(running_loss)
            if self.comet:
                self.comet.log_metric('mean_train_loss', training_loss)
            loss_per_epoch.append(training_loss)
            if validate:
                validation_loss, validation_preds = self.__validate__()
                mean_validation_loss = sum(validation_loss) / len(validation_loss)
                if self.comet:
                    self.comet.log_metric('mean_validation_loss',
                                          mean_validation_loss)
                preds_per_epoch.append(validation_preds)
                print('Validation loss: {}'.format(mean_validation_loss))
            # Save the model at this stage
            if self.models_path:
                (self.models_path / self.model_name).mkdir(exist_ok=True)
                torch.save(self.model, (self.models_path / self.model_name /
                                        'epoch_{}'.format(epoch)).resolve())
            print('Training Loss: {}'.format(training_loss))
        # Save the validation frames, ground truths and predictions at this stage
        if self.logs_path:
            (self.logs_path / self.model_name).mkdir(exist_ok=True)
            save = {
                'frames': self.data.video_frames[-1],
                'truth': self.data.video_annotations[-1],
                'preds': preds_per_epoch
            }
            np.save(
                str((self.logs_path / self.model_name /
                     'preds_per_epoch.npy').resolve()), save)
        # Test the model and save the results
        if test:
            test_loss, test_preds = self.__test__()
            if self.logs_path:
                (self.logs_path / self.model_name).mkdir(exist_ok=True)
                save = {
                    'frames': self.data.video_frames[-1],
                    'truth': self.data.video_annotations[-1],
                    'preds': test_preds,
                    'loss': test_loss
                }
                np.save(
                    str((self.logs_path / self.model_name /
                         'test_preds.npy').resolve()), save)
        return loss_per_epoch
Ejemplo n.º 9
0
def run(args, train, sparse_evidences, claims_dict):
    """Train a CDSSM claim/evidence model and validate it after every epoch.

    The first 80% of ``train`` is used for training and the rest for validation.
    Progress is printed to stdout and logged to ``Logger`` and to a Comet
    experiment. A checkpoint is written through ``putils.save_checkpoint`` after
    every epoch. Inputs are moved with ``.cuda()``, so a CUDA device is required.

    The Comet API key is not passed explicitly. comet_ml reads it from the
    ``COMET_API_KEY`` environment variable or its config file.

    Args:
        args: Parsed command-line arguments. Reads ``batch_size``,
            ``learning_rate``, ``data_sampling``, ``epochs``, ``model``,
            ``no_randomize``, ``print``, ``data`` and ``sparse_evidences``.
        train: Sliceable sequence of training examples.
        sparse_evidences: Passed through to ``pytorch_data_loader.WikiDataset``.
        claims_dict: Passed through to ``pytorch_data_loader.WikiDataset``.
    """
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    DATA_SAMPLING = args.data_sampling
    NUM_EPOCHS = args.epochs
    MODEL = args.model
    RANDOMIZE = args.no_randomize
    PRINT = args.print

    logger = Logger('./logs/{}'.format(time.localtime()))

    if MODEL:
        print("Loading pretrained model...")
        # torch.load already returns the saved module; reloading its own
        # state_dict into it (as a second torch.load) is redundant.
        model = torch.load(MODEL)
    else:
        model = cdssm.CDSSM()
        model = model.cuda()

    if torch.cuda.device_count() > 0:
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
        model = nn.DataParallel(model)

    print("Created model with {:,} parameters.".format(
        putils.count_parameters(model)))

    print("Created dataset...")

    # use an 80/20 train/validate split!
    train_size = int(len(train) * 0.80)
    train_dataset = pytorch_data_loader.WikiDataset(
        train[:train_size],
        claims_dict,
        data_sampling=DATA_SAMPLING,
        sparse_evidences=sparse_evidences,
        randomize=RANDOMIZE)
    val_dataset = pytorch_data_loader.WikiDataset(
        train[train_size:],
        claims_dict,
        data_sampling=DATA_SAMPLING,
        sparse_evidences=sparse_evidences,
        randomize=RANDOMIZE)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  num_workers=0,
                                  shuffle=True,
                                  collate_fn=pytorch_data_loader.PadCollate())
    val_dataloader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                num_workers=0,
                                shuffle=True,
                                collate_fn=pytorch_data_loader.PadCollate())

    # Loss and optimizer
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=1e-3)

    # Print/log roughly every 2% of an epoch, but never more often than every 20 batches.
    OUTPUT_FREQ = max(int((len(train_dataset) / BATCH_SIZE) * 0.02), 20)
    parameters = {
        "batch size": BATCH_SIZE,
        "epochs": NUM_EPOCHS,
        "learning rate": LEARNING_RATE,
        "optimizer": optimizer.__class__.__name__,
        "loss": criterion.__class__.__name__,
        "training size": train_size,
        "data sampling rate": DATA_SAMPLING,
        "data": args.data,
        "sparse_evidences": args.sparse_evidences,
        "randomize": RANDOMIZE,
        "model": MODEL
    }
    # Never hard-code credentials: comet_ml falls back to COMET_API_KEY / .comet.config.
    experiment = Experiment(project_name="clsm",
                            workspace="moinnadeem")
    experiment.add_tag("train")
    experiment.log_asset("cdssm.py")
    experiment.log_dataset_info(name=args.data)
    experiment.log_parameters(parameters)

    # Checkpoint name encodes every parameter except the pretrained-model path.
    model_checkpoint_dir = "models/saved_model"
    for key, value in parameters.items():
        if type(value) == str:
            value = value.replace("/", "-")
        if key != "model":
            model_checkpoint_dir += "_{}-{}".format(key.replace(" ", "_"),
                                                    value)

    print("Training...")
    best_loss = torch.tensor(float("inf"),
                             dtype=torch.float)  # begin loss at infinity

    for epoch in range(NUM_EPOCHS):
        beginning_time = time.time()
        train_running_loss = 0.0
        train_running_accuracy = 0.0
        model.train()
        experiment.log_current_epoch(epoch)

        with experiment.train():
            for train_batch_num, inputs in enumerate(train_dataloader):
                claims_tensors, claims_text, evidences_tensors, evidences_text, labels = inputs

                claims_tensors = claims_tensors.cuda()
                evidences_tensors = evidences_tensors.cuda()
                labels = labels.cuda()

                # NOTE(review): squeeze() also drops the batch dimension when a
                # batch has a single example — confirm the model's output shape.
                y_pred = model(claims_tensors, evidences_tensors).squeeze()
                y = labels

                # argmax over dim 1 turns each label row into a class index
                # (presumably one-hot labels — confirm with the data loader).
                loss = criterion(y_pred, torch.max(y, 1)[1])

                y = y.float()
                binary_y = torch.max(y, 1)[1]
                binary_pred = torch.max(y_pred, 1)[1]
                accuracy = (binary_y == binary_pred).float().mean()
                train_running_accuracy += accuracy.item()
                train_running_loss += loss.item()

                if PRINT:
                    for idx in range(len(y)):
                        print(
                            "Claim: {}, Evidence: {}, Prediction: {}, Label: {}"
                            .format(claims_text[0], evidences_text[idx],
                                    torch.exp(y_pred[idx]), y[idx]))

                if (train_batch_num %
                        OUTPUT_FREQ) == 0 and train_batch_num > 0:
                    elapsed_time = time.time() - beginning_time
                    # Monotonic step across epochs; batch * (epoch + 1) collided between epochs.
                    global_step = epoch * len(train_dataloader) + train_batch_num
                    print(
                        "[{}:{}:{:3f}s] training loss: {}, training accuracy: {}, training recall: {}"
                        .format(
                            epoch, train_batch_num /
                            (len(train_dataset) / BATCH_SIZE), elapsed_time,
                            train_running_loss / OUTPUT_FREQ,
                            train_running_accuracy / OUTPUT_FREQ,
                            recall_score(binary_y.cpu().detach().numpy(),
                                         binary_pred.cpu().detach().numpy())))

                    # 1. Log scalar values (scalar summary)
                    info = {
                        'train_loss': train_running_loss / OUTPUT_FREQ,
                        'train_accuracy': train_running_accuracy / OUTPUT_FREQ
                    }

                    for tag, value in info.items():
                        experiment.log_metric(tag, value, step=global_step)
                        logger.scalar_summary(tag, value, global_step + 1)

                    # 2. Log values and gradients of the parameters (histogram summary)
                    _log_param_histograms(logger, model, global_step + 1)

                    train_running_loss = 0.0
                    beginning_time = time.time()
                    train_running_accuracy = 0.0
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        print("Running validation...")
        model.eval()
        pred = []
        true = []
        avg_loss = 0.0
        val_running_accuracy = 0.0
        val_running_loss = 0.0
        beginning_time = time.time()
        # Evaluation only: no_grad avoids building autograd graphs and saves GPU memory.
        with experiment.validate(), torch.no_grad():
            for val_batch_num, val_inputs in enumerate(val_dataloader):
                claims_tensors, claims_text, evidences_tensors, evidences_text, labels = val_inputs

                claims_tensors = claims_tensors.cuda()
                evidences_tensors = evidences_tensors.cuda()
                labels = labels.cuda()

                y_pred = model(claims_tensors, evidences_tensors).squeeze()
                y = labels

                loss = criterion(y_pred, torch.max(y, 1)[1])

                y = y.float()

                binary_y = torch.max(y, 1)[1]
                binary_pred = torch.max(y_pred, 1)[1]
                true.extend(binary_y.tolist())
                pred.extend(binary_pred.tolist())

                accuracy = (binary_y == binary_pred).float().mean()
                val_running_accuracy += accuracy.item()
                val_running_loss += loss.item()
                avg_loss += loss.item()

                if (val_batch_num % OUTPUT_FREQ) == 0 and val_batch_num > 0:
                    elapsed_time = time.time() - beginning_time
                    val_step = epoch * len(val_dataloader) + val_batch_num
                    print(
                        "[{}:{}:{:3f}s] validation loss: {}, accuracy: {}, recall: {}"
                        .format(
                            epoch,
                            val_batch_num / (len(val_dataset) / BATCH_SIZE),
                            elapsed_time, val_running_loss / OUTPUT_FREQ,
                            val_running_accuracy / OUTPUT_FREQ,
                            recall_score(binary_y.cpu().detach().numpy(),
                                         binary_pred.cpu().detach().numpy())))

                    # 1. Log scalar values (scalar summary)
                    info = {'val_accuracy': val_running_accuracy / OUTPUT_FREQ}

                    for tag, value in info.items():
                        experiment.log_metric(tag, value, step=val_step)
                        logger.scalar_summary(tag, value, val_step + 1)

                    # 2. Log values and gradients of the parameters (histogram summary)
                    _log_param_histograms(logger, model, val_step + 1)

                    val_running_accuracy = 0.0
                    val_running_loss = 0.0
                    beginning_time = time.time()

        mean_val_loss = avg_loss / len(val_dataloader)
        accuracy = accuracy_score(true, pred)
        print("[{}] mean accuracy: {}, mean loss: {}".format(
            epoch, accuracy, mean_val_loss))

        true = np.array(true).astype("int")
        pred = np.array(pred).astype("int")
        print(classification_report(true, pred))

        # best_loss is updated first, so is_best holds exactly when this epoch's
        # loss is not worse than any previous epoch's.
        best_loss = torch.tensor(
            min(mean_val_loss,
                best_loss.cpu().numpy()))
        is_best = bool(mean_val_loss <= best_loss)

        putils.save_checkpoint(
            {
                "epoch": epoch,
                "model": model,
                "best_loss": best_loss
            },
            is_best,
            filename="{}_loss_{}".format(model_checkpoint_dir,
                                         best_loss.cpu().numpy()))


def _log_param_histograms(logger, model, step):
    """Write a histogram of every parameter, and of its gradient when one exists, to ``logger``."""
    for tag, value in model.named_parameters():
        tag = tag.replace('.', '/')
        logger.histo_summary(tag, value.detach().cpu().numpy(), step)
        # Parameters that have not been through a backward pass have grad None.
        if value.grad is not None:
            logger.histo_summary(tag + '/grad',
                                 value.grad.detach().cpu().numpy(), step)
Ejemplo n.º 10
0
class Trainer():
    """Train and evaluate a MAC model together with an exponential-moving-average copy.

    Datasets, dataloaders, model, optimizer, TensorBoard writer and Comet
    experiment are all built from ``cfg`` in ``__init__``; call ``train()`` to
    run the training loop.
    """

    def __init__(self, log_dir, cfg):
        """Set up data, model, optimizer and loggers.

        Args:
            log_dir: Root output directory. When ``cfg.TRAIN.FLAG`` is set,
                snapshots go to ``<log_dir>/Model`` and TensorBoard logs to
                ``<log_dir>/Log``; ``cfg.json`` is always written here.
            cfg: Experiment configuration (attribute access, JSON-serialisable).

        Raises:
            ValueError: If ``cfg.DATASET.DATASET == 'gqa'`` and
                ``cfg.model.use_feats`` is neither ``'spatial'`` nor ``'objects'``.
        """
        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.logfile = os.path.join(self.path, "logfile.log")
            # Route stdout through Logger (presumably tees prints into self.logfile).
            sys.stdout = Logger(logfile=self.logfile)

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # GPU_ID is a comma-separated list such as "0,1"; the first one is used.
        self.gpus = [int(ix) for ix in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        sample = cfg.SAMPLE
        self.dataset = []
        self.dataloader = []
        self.use_feats = cfg.model.use_feats
        eval_split = cfg.EVAL if cfg.EVAL else 'val'
        train_split = cfg.DATASET.train_split
        if cfg.DATASET.DATASET == 'clevr':
            clevr_collate_fn = collate_fn
            cogent = cfg.DATASET.COGENT
            if cogent:
                print(f'Using CoGenT {cogent.upper()}')

            if cfg.TRAIN.FLAG:
                self.dataset = ClevrDataset(data_dir=self.data_dir,
                                            split=train_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=clevr_collate_fn)

            self.dataset_val = ClevrDataset(data_dir=self.data_dir,
                                            split=eval_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             drop_last=False,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             collate_fn=clevr_collate_fn)

        elif cfg.DATASET.DATASET == 'gqa':
            if self.use_feats == 'spatial':
                gqa_collate_fn = collate_fn_gqa
            elif self.use_feats == 'objects':
                gqa_collate_fn = collate_fn_gqa_objs
            else:
                # Previously fell through and failed later with a NameError.
                raise ValueError(
                    f"Unsupported use_feats for GQA: {self.use_feats!r} "
                    "(expected 'spatial' or 'objects')")
            if cfg.TRAIN.FLAG:
                self.dataset = GQADataset(data_dir=self.data_dir,
                                          split=train_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=gqa_collate_fn)

            self.dataset_val = GQADataset(data_dir=self.data_dir,
                                          split=eval_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             drop_last=False,
                                             collate_fn=gqa_collate_fn)

        # load model
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)

        # alpha=0 copies the live weights into the EMA model.
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.start_epoch = 0
        if cfg.resume_model:
            location = 'cuda' if cfg.CUDA else 'cpu'
            state = torch.load(cfg.resume_model, map_location=location)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
            self.start_epoch = state['iter'] + 1
            state = torch.load(cfg.resume_model_ema, map_location=location)
            self.model_ema.load_state_dict(state['model'])

        if cfg.start_epoch is not None:
            self.start_epoch = cfg.start_epoch

        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0
        self.previous_best_loss = 100
        self.previous_best_loss_epoch = 0

        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

        self.comet_exp = Experiment(
            project_name=cfg.COMET_PROJECT_NAME,
            api_key=os.getenv('COMET_API_KEY'),
            workspace=os.getenv('COMET_WORKSPACE'),
            disabled=cfg.logcomet is False,
        )
        if cfg.logcomet:
            exp_name = cfg_to_exp_name(cfg)
            print(exp_name)
            self.comet_exp.set_name(exp_name)
            self.comet_exp.log_parameters(flatten_json_iterative_solution(cfg))
            # self.logfile only exists when training (cfg.TRAIN.FLAG).
            if cfg.TRAIN.FLAG:
                self.comet_exp.log_asset(self.logfile)
            self.comet_exp.log_asset_data(json.dumps(cfg, indent=4),
                                          file_name='cfg.json')
            self.comet_exp.set_model_graph(str(self.model))
            if cfg.cfg_file:
                self.comet_exp.log_asset(cfg.cfg_file)

        with open(os.path.join(self.path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=4)

    def print_info(self):
        """Print the config, dataset sizes and model architecture."""
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")

        pprint.pprint("Size of train dataset: {}".format(len(self.dataset)))
        pprint.pprint("Size of val dataset: {}".format(len(self.dataset_val)))
        print("\n")

        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
        """Update EMA weights in place: ``ema = alpha * ema + (1 - alpha) * model``."""
        for param1, param2 in zip(self.model_ema.parameters(),
                                  self.model.parameters()):
            param1.data *= alpha
            param1.data += (1.0 - alpha) * param2.data

    def set_mode(self, mode="train"):
        """Put both models in train mode for ``"train"``, eval mode otherwise."""
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
        """Halve the learning rate when the epoch loss has plateaued.

        Compares ``total_epoch_loss`` with the previous epoch's loss against
        three (improvement, loss, lr) threshold tiers, then shifts the current
        loss into ``prior_epoch_loss`` and resets ``total_epoch_loss``.
        """
        epoch_loss = self.total_epoch_loss
        lossDiff = self.prior_epoch_loss - epoch_loss
        if ((lossDiff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or
                (lossDiff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or
                (lossDiff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        """Save the live model (with optimizer state) and the EMA model to ``model_dir``."""
        save_model(self.model,
                   self.optimizer,
                   iteration,
                   self.model_dir,
                   model_name="model")
        save_model(self.model_ema,
                   None,
                   iteration,
                   self.model_dir,
                   model_name="model_ema")

    def _prepare_batch(self, data):
        """Unpack a collated batch and move it to the GPU when ``cfg.CUDA`` is set.

        Returns ``(image, question, question_len, answer)``. ``answer`` is cast
        to ``long`` and has a trailing singleton dim removed, so ``(B, 1)``
        becomes ``(B,)`` on both the CUDA and CPU paths.
        """
        image = data['image']
        question = data['question']
        question_len = data['question_length']
        answer = data['answer'].long()

        if self.cfg.CUDA:
            if self.use_feats == 'spatial':
                image = image.cuda()
            elif self.use_feats == 'objects':
                image = [e.cuda() for e in image]
            question = question.cuda()
            answer = answer.cuda()

        # Only squeeze the trailing dim, never the batch dim, so a final batch
        # of size 1 stays 1-D (plain .squeeze() would make it 0-d).
        if answer.dim() > 1:
            answer = answer.squeeze(-1)
        return image, question, question_len, answer

    def train_epoch(self, epoch):
        """Run one optimisation pass over the training dataloader.

        Args:
            epoch: Zero-based epoch index (used only for the progress-bar label).

        Returns:
            dict with ``loss``/``avg_loss`` (sample-weighted mean loss) and
            ``accuracy``/``avg_accuracy`` (fraction correct). The ``avg_*`` keys
            duplicate the others for Comet. Both are 0.0 for an empty loader.
        """
        total_loss = 0.
        total_correct = 0
        total_samples = 0
        # Defaults so an empty dataloader reports zeros instead of raising.
        avg_loss = 0.
        train_accuracy = 0.

        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        pbar = tqdm(self.labeled_data, total=len(self.dataloader), ncols=20)

        for data in pbar:
            # (1) Prepare training data
            image, question, question_len, answer = self._prepare_batch(data)

            # (2) Train model
            self.optimizer.zero_grad()

            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()

            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.TRAIN.CLIP)

            self.optimizer.step()
            # Update the EMA copy after every optimizer step.
            self.weight_moving_average()

            # (3) Log progress: running sample-weighted averages over the epoch.
            batch_size = answer.size(0)
            total_correct += (scores.detach().argmax(1) == answer).sum().item()
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples

            pbar.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(
                    epoch + 1, avg_loss, train_accuracy))

        self.total_epoch_loss = avg_loss

        return {
            "loss": avg_loss,
            "accuracy": train_accuracy,
            "avg_loss": avg_loss,  # For comet
            "avg_accuracy": train_accuracy,  # For comet
        }

    def train(self):
        """Train from ``start_epoch`` to ``max_epochs - 1`` and validate after each epoch.

        Logs train and validation metrics to Comet and stops early if
        ``cfg.TRAIN.EALRY_STOPPING`` is set. At the end it uploads the log
        file, saves final snapshots and closes the TensorBoard writer.
        """
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.start_epoch, self.max_epochs):

            with self.comet_exp.train():
                train_metrics = self.train_epoch(epoch)
                self.reduce_lr()
                train_metrics['epoch'] = epoch + 1
                train_metrics['lr'] = self.lr
                self.comet_exp.log_metrics(train_metrics, epoch=epoch + 1)

            with self.comet_exp.validate():
                val_metrics = self.log_results(epoch, train_metrics)
                val_metrics['epoch'] = epoch + 1
                val_metrics['lr'] = self.lr
                self.comet_exp.log_metrics(val_metrics, epoch=epoch + 1)

            # NOTE: previous_best_epoch is 1-based (set in log_results) while
            # `epoch` here is 0-based.
            if cfg.TRAIN.EALRY_STOPPING:
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    print('Early stop')
                    break

        self.comet_exp.log_asset(self.logfile)
        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        print(
            f"Highest validation accuracy: {self.previous_best_acc} at epoch {self.previous_best_epoch}"
        )

    def log_results(self, epoch, dict, max_eval_samples=None):
        """Validate, write TensorBoard scalars, track best epochs and snapshot.

        Args:
            epoch: Zero-based epoch index (converted to 1-based internally).
            dict: Training metrics; must contain ``"loss"`` and ``"accuracy"``.
                (The name shadows the builtin but is kept for keyword callers.)
            max_eval_samples: Forwarded to ``calc_accuracy`` as ``max_samples``.

        Returns:
            The validation metrics dict produced by ``calc_accuracy``.
        """
        epoch += 1
        self.writer.add_scalar("avg_loss", dict["loss"], epoch)
        self.writer.add_scalar("train_accuracy", dict["accuracy"], epoch)

        metrics = self.calc_accuracy("validation",
                                     max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", metrics['acc_ema'], epoch)
        self.writer.add_scalar("val_accuracy", metrics['acc'], epoch)
        self.writer.add_scalar("val_loss_ema", metrics['loss_ema'], epoch)
        self.writer.add_scalar("val_loss", metrics['loss'], epoch)

        print(
            "Epoch: {epoch}\tVal Acc: {acc},\tVal Acc EMA: {acc_ema},\tAvg Loss: {loss},\tAvg Loss EMA: {loss_ema},\tLR: {lr}"
            .format(epoch=epoch, lr=self.lr, **metrics))

        if metrics['acc'] > self.previous_best_acc:
            self.previous_best_acc = metrics['acc']
            self.previous_best_epoch = epoch
        if metrics['loss'] < self.previous_best_loss:
            self.previous_best_loss = metrics['loss']
            self.previous_best_loss_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

        return metrics

    def calc_accuracy(self, mode="train", max_samples=None):
        """Evaluate both the live model and its EMA copy without gradients.

        Args:
            mode: ``"train"`` evaluates on the training loader; any other value
                uses the validation loader.
            max_samples: If given, stop once at least this many samples have
                been evaluated (the last batch is not truncated). ``None``
                evaluates the whole loader.

        Returns:
            dict with ``acc``, ``acc_ema``, ``loss`` and ``loss_ema``
            (sample-weighted averages; all 0.0 if no batch was evaluated).
        """
        self.set_mode("validation")

        loader = self.dataloader if mode == "train" else self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0
        total_loss = 0.
        total_loss_ema = 0.
        # Defaults so an empty loader returns zeros instead of raising.
        avg_acc = avg_acc_ema = avg_loss = avg_loss_ema = 0.
        pbar = tqdm(loader, total=len(loader), desc=mode.upper(), ncols=20)
        for data in pbar:
            # Shared with train_epoch so the answer is squeezed on CPU too.
            image, question, question_len, answer = self._prepare_batch(data)

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)

                loss = self.loss_fn(scores, answer)
                loss_ema = self.loss_fn(scores_ema, answer)

            batch_size = answer.size(0)
            total_correct += (scores.argmax(1) == answer).sum().item()
            total_correct_ema += (scores_ema.argmax(1) == answer).sum().item()

            total_loss += loss.item() * batch_size
            total_loss_ema += loss_ema.item() * batch_size

            total_samples += batch_size

            avg_acc = total_correct / total_samples
            avg_acc_ema = total_correct_ema / total_samples
            avg_loss = total_loss / total_samples
            avg_loss_ema = total_loss_ema / total_samples

            pbar.set_postfix({
                'Acc': f'{avg_acc:.5f}',
                'Acc Ema': f'{avg_acc_ema:.5f}',
                'Loss': f'{avg_loss:.5f}',
                'Loss Ema': f'{avg_loss_ema:.5f}',
            })

            if max_samples is not None and total_samples >= max_samples:
                break

        return dict(acc=avg_acc,
                    acc_ema=avg_acc_ema,
                    loss=avg_loss,
                    loss_ema=avg_loss_ema)
Ejemplo n.º 11
0
def main():
    """Entry point: parse CLI args, build the VQA data/model, then evaluate or train.

    Relies on the module-level ``parser`` and ``best_acc1`` globals and on
    project helpers (``datasets``, ``models``, ``criterions``, ``engine``,
    ``utils``, ``logger``, ``io_utils``, checkpoint/result helpers).
    Returns early after ``--help_opt``, after ``--evaluate``, or if the user
    declines to erase an existing logs directory.
    """
    global args, best_acc1
    import shutil

    args = parser.parse_args()

    #########################################################################################
    # Create options
    #########################################################################################
    if args.bert_model == "bert-base-uncased":
        question_features_path = BASE_EXTRACTED_QUES_FEATURES_PATH
    elif args.bert_model == "bert-base-multilingual-cased":
        question_features_path = CASED_EXTRACTED_QUES_FEATURES_PATH
    else:
        question_features_path = EXTRACTED_QUES_FEATURES_PATH

    options = {
        'vqa': {
            'trainsplit': args.vqa_trainsplit
        },
        'logs': {
            'dir_logs': args.dir_logs
        },
        'model': {
            'arch': args.arch,
            'seq2vec': {
                'type': args.st_type,
                'dropout': args.st_dropout,
                'fixed_emb': args.st_fixed_emb
            }
        },
        'optim': {
            'lr': args.learning_rate,
            'batch_size': args.batch_size,
            'epochs': args.epochs
        }
    }
    if args.path_opt is not None:
        # NOTE(review): FullLoader can construct some Python objects; only
        # load option files from trusted sources.
        with open(args.path_opt, 'r') as handle:
            options_yaml = yaml.load(handle, Loader=yaml.FullLoader)
        options = utils.update_values(options, options_yaml)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)
    if args.help_opt:
        return

    # Set datasets options
    if 'vgenome' not in options:
        options['vgenome'] = None

    dir_logs = options['logs']['dir_logs']

    #########################################################################################
    # Create needed datasets
    #########################################################################################

    trainset = datasets.factory_VQA(options['vqa']['trainsplit'],
                                    options['vqa'], options['coco'],
                                    options['vgenome'])
    train_loader = trainset.data_loader(
        batch_size=options['optim']['batch_size'],
        num_workers=args.workers,
        shuffle=True)

    if options['vqa']['trainsplit'] == 'train':
        valset = datasets.factory_VQA('val', options['vqa'], options['coco'])
        val_loader = valset.data_loader(
            batch_size=options['optim']['batch_size'],
            num_workers=args.workers)

    if options['vqa']['trainsplit'] == 'trainval' or args.evaluate:
        testset = datasets.factory_VQA('test', options['vqa'], options['coco'])
        test_loader = testset.data_loader(
            batch_size=options['optim']['batch_size'],
            num_workers=args.workers)

    #########################################################################################
    # Create model, criterion and optimizer
    #########################################################################################

    model = models.factory(options['model'],
                           trainset.vocab_words(),
                           trainset.vocab_answers(),
                           cuda=True,
                           data_parallel=True)
    criterion = criterions.factory(options['vqa'], cuda=True)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        options['optim']['lr'])

    #########################################################################################
    # args.resume: resume from a checkpoint OR create logs directory
    #########################################################################################

    exp_logger = None
    if args.resume:
        args.start_epoch, best_acc1, exp_logger = load_checkpoint(
            model.module, optimizer, os.path.join(dir_logs, args.resume))
    else:
        # Or create logs directory
        if os.path.isdir(dir_logs):
            # `default` belongs to click.confirm; it was previously passed to
            # str.format by mistake and silently ignored.
            if click.confirm(
                    'Logs directory already exists in {}. Erase?'.format(
                        dir_logs),
                    default=False):
                # No shell involved, so paths with spaces/metacharacters are safe.
                shutil.rmtree(dir_logs)
            else:
                return
        os.makedirs(dir_logs, exist_ok=True)
        # basename(None) would raise; fall back to a fixed name when no
        # option file was given.
        opt_name = (os.path.basename(args.path_opt)
                    if args.path_opt is not None else 'options.yaml')
        path_new_opt = os.path.join(dir_logs, opt_name)
        path_args = os.path.join(dir_logs, 'args.yaml')
        with open(path_new_opt, 'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(path_args, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if exp_logger is None:
        # Set loggers
        exp_name = os.path.basename(dir_logs)  # add timestamp
        exp_logger = logger.Experiment(exp_name, options)
        exp_logger.add_meters('train', make_meters())
        exp_logger.add_meters('test', make_meters())
        if options['vqa']['trainsplit'] == 'train':
            exp_logger.add_meters('val', make_meters())
        exp_logger.info['model_params'] = utils.params_count(model)
        print('Model has {} parameters'.format(
            exp_logger.info['model_params']))

    #########################################################################################
    # args.evaluate: on valset OR/AND on testset
    #########################################################################################

    if args.evaluate:
        path_logger_json = os.path.join(dir_logs, 'logger.json')

        if options['vqa']['trainsplit'] == 'train':
            acc1, val_results = engine.validate(val_loader, model, criterion,
                                                exp_logger, args.start_epoch,
                                                args.print_freq)
            # save results and compute OpenEnd accuracy
            exp_logger.to_json(path_logger_json)
            save_results(val_results, args.start_epoch, valset.split_name(),
                         dir_logs, options['vqa']['dir'])

        test_results, testdev_results = engine.test(test_loader, model,
                                                    exp_logger,
                                                    args.start_epoch,
                                                    args.print_freq)
        # save results and DOES NOT compute OpenEnd accuracy
        exp_logger.to_json(path_logger_json)
        save_results(test_results, args.start_epoch, testset.split_name(),
                     dir_logs, options['vqa']['dir'])
        save_results(testdev_results, args.start_epoch,
                     testset.split_name(testdev=True),
                     dir_logs, options['vqa']['dir'])
        return

    #########################################################################################
    # Begin training on train/val or trainval/test
    #########################################################################################
    # NOTE(review): API key is hard-coded in source control; consider moving
    # it to an environment variable / comet config file and rotating it.
    experiment = Experiment(api_key="AgTGwIoRULRgnfVR5M8mZ5AfS",
                            project_name="vqa",
                            workspace="vuhoangminh")
    experiment.log_parameters(flatten(options))

    # Question features are only needed on the trainval/test path; load them
    # once instead of re-reading the pickle every epoch.
    question_features = None

    with experiment.train():
        for epoch in range(args.start_epoch + 1, options['optim']['epochs']):

            engine.train(train_loader, model, criterion, optimizer, exp_logger,
                         epoch, experiment, args.print_freq)

            if options['vqa']['trainsplit'] == 'train':
                # evaluate on validation set
                with experiment.validate():
                    acc1, val_results = engine.validate(
                        val_loader, model, criterion, exp_logger, epoch,
                        args.print_freq)
                    # this will be logged as validation accuracy based on the context.
                    experiment.log_metric("acc1", acc1)

                # remember best prec@1 and save checkpoint
                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': options['model']['arch'],
                        'best_acc1': best_acc1,
                        'exp_logger': exp_logger
                    }, model.module.state_dict(), optimizer.state_dict(),
                    dir_logs, args.save_model,
                    args.save_all_from, is_best)

                # save results and compute OpenEnd accuracy
                save_results(val_results, epoch, valset.split_name(),
                             dir_logs, options['vqa']['dir'])

            else:
                if question_features is None:
                    question_features = io_utils.read_pickle(
                        question_features_path)
                test_results, testdev_results = engine.test(
                    test_loader,
                    model,
                    exp_logger,
                    epoch,
                    args.print_freq,
                    topk=3,
                    dict=question_features,
                    bert_dim=options["model"]["dim_q"])

                # save checkpoint at every timestep
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': options['model']['arch'],
                        'best_acc1': best_acc1,
                        'exp_logger': exp_logger
                    }, model.module.state_dict(), optimizer.state_dict(),
                    dir_logs, args.save_model,
                    args.save_all_from)

                # save results and DOES NOT compute OpenEnd accuracy
                save_results(test_results, epoch, testset.split_name(),
                             dir_logs, options['vqa']['dir'])
                save_results(testdev_results, epoch,
                             testset.split_name(testdev=True),
                             dir_logs, options['vqa']['dir'])
Ejemplo n.º 12
0
    def train(self, model, pair_generator, fold, output_file, use_nprf=False):
        '''Driver function for training.

        Rotates the qid partitions in ``self.config.qid_list`` by ``fold - 1``.
        After rotation, partitions 0-2 form the training set, partition 3 the
        validation set and partition 4 the test set. Training runs in chunks
        of 100 generator steps. After every chunk the model is evaluated on
        validation and test, and its weights are saved whenever validation
        MAP improves.

        Args:
          model: a keras Model
          pair_generator: an instantiated pair generator
          fold: which fold to run. partitions will be automatically rotated.
          output_file: temporary file for validation
          use_nprf: whether to use nprf

        Returns:
          The second value returned by ``self._extract_max_metric(met)``.
          Presumably these are the metrics of the best iteration; see that
          helper to confirm.
        '''
        initial_lrate = self.config.learning_rate
        # NOTE(review): API key is hard-coded in source control; consider
        # moving it to an environment variable / comet config and rotating it.
        experiment = Experiment(api_key="PhzBYNpSC304fMjGUoU42dX9b",
                                project_name="nprf-drmm",
                                workspace="neural-ir",
                                auto_param_logging=False)
        experiment_params = {
            "batch_size": self.config.batch_size,
            "optimizer": self.config.optimizer,
            "epochs": self.config.max_iteration,
            "initial_learning_rate": initial_lrate
        }
        experiment.log_multiple_params(experiment_params)

        # Rotate the partitions so each fold sees a different train/valid/test
        # split (previously written as an obfuscated `map(rotate(...), ...)`).
        qid_list = deque(self.config.qid_list)
        qid_list.rotate(fold - 1)

        train_qid_list = qid_list[0] + qid_list[1] + qid_list[2]
        valid_qid_list, test_qid_list = qid_list[3], qid_list[4]
        print(train_qid_list, valid_qid_list, test_qid_list)
        relevance_dict = load_pickle(self.config.relevance_dict_path)
        nb_pair_train = pair_generator.count_pairs_balanced(
            train_qid_list, self.config.pair_sample_size)

        valid_params = self.eval_by_qid_list_helper(valid_qid_list,
                                                    pair_generator)
        test_params = self.eval_by_qid_list_helper(test_qid_list,
                                                   pair_generator)

        print(valid_params[-1], test_params[-1])
        batch_logger = NBatchLogger(50)
        batch_losses = []
        # met[0..2]: validation MAP/P@20/NDCG@20; met[3..5]: the same on test.
        met = [[], [], [], [], [], []]
        metric_names = ("map", "p@20", "ndcg@20")

        iteration = -1
        best_valid_map = 0.0
        new_lrate = initial_lrate

        # Integer division: with `/`, Python 3 yields a float and range() fails.
        nb_batch = int(nb_pair_train // self.config.batch_size)
        nb_chunks = nb_batch // 100

        # Evaluation kwargs are loop-invariant; only output_file changes.
        base_kwargs = {
            'model': model,
            'relevance_dict': relevance_dict,
            'rerank_topk': self.config.rerank_topk,
            'qrels_file': self.config.qrels_file,
            'docnolist_file': self.config.docnolist_file,
            'runid': self.config.runid,
            'output_file': output_file
        }
        if use_nprf:
            base_kwargs.update({
                'nb_supervised_doc': self.config.nb_supervised_doc,
                'doc_topk_term': self.config.doc_topk_term,
            })

        for i in range(self.config.nb_epoch):
            print("Epoch " + str(i))

            train_generator = pair_generator.generate_pair_batch(
                train_qid_list, self.config.pair_sample_size)
            for j in range(nb_chunks):
                iteration += 1
                new_lrate = self._step_decay(iteration, initial_lrate)
                K.set_value(model.optimizer.lr, new_lrate)

                history = model.fit_generator(
                    generator=train_generator,
                    steps_per_epoch=100,
                    epochs=1,
                    shuffle=False,
                    verbose=0,
                    callbacks=[batch_logger],
                )
                batch_losses.append(batch_logger.losses)
                print("[Iter {0}]\tLoss: {1}\tlr: {2}".format(
                    iteration, history.history['loss'][0], new_lrate))
                experiment.log_parameter("curr_epoch",
                                         iteration + 1,
                                         step=(iteration + 1))
                experiment.log_parameter("curr_lr",
                                         new_lrate,
                                         step=(iteration + 1))
                experiment.log_metric("curr_loss",
                                      history.history['loss'][0],
                                      step=(iteration + 1))

                kwargs = base_kwargs.copy()

                valid_met = self.eval_by_qid_list(*valid_params, **kwargs)
                print("[Valid]\t\tMAP\tP20\tNDCG20")
                print("\t\t{0}\t{1}\t{2}".format(valid_met[0], valid_met[1],
                                                 valid_met[2]))
                for k in range(3):
                    met[k].append(valid_met[k])

                with experiment.validate():
                    for name, value in zip(metric_names, valid_met):
                        experiment.log_metric(name, value,
                                              step=(iteration + 1))

                if valid_met[0] > best_valid_map:
                    model.save_weights(
                        os.path.join(self.config.save_path,
                                     "fold{0}.h5".format(fold)))
                    best_valid_map = valid_met[0]

                kwargs['output_file'] = os.path.join(
                    self.config.result_path,
                    "fold{0}.iter{1}.res".format(fold, iteration))
                test_met = self.eval_by_qid_list(*test_params, **kwargs)
                print("[Test]\t\tMAP\tP20\tNDCG20")
                print("\t\t{0}\t{1}\t{2}".format(test_met[0], test_met[1],
                                                 test_met[2]))
                for k in range(3):
                    met[3 + k].append(test_met[k])

                with experiment.test():
                    for name, value in zip(metric_names, test_met):
                        experiment.log_metric(name, value,
                                              step=(iteration + 1))

            # Guard: with fewer than 100 batches no iteration has run yet and
            # max([]) would raise ValueError.
            if met[0]:
                print("[Attention]\t\tCurrent best iteration {0}\n".format(
                    met[0].index(max(met[0]))))
            if iteration > self.config.max_iteration:
                break
        best_iter, eval_met = self._extract_max_metric(met)
        retain_file(self.config.result_path, "fold{0}".format(fold),
                    "fold{0}.iter{1}.res".format(fold, best_iter))
        return eval_met
Ejemplo n.º 13
0
class ModelTrainer:
    """Train, validate and test a classification model, logging to Comet.

    Args:
        model: ``torch.nn.Module`` to optimise. ``train`` also reads
            ``model.name`` for its progress messages.
        dataloader: dict of loaders keyed by ``'train'``, ``'val'`` and
            ``'test'``, or ``None`` (then ``frq_log`` is not set and the
            training loop cannot be used).
        args: parsed options namespace providing the optimiser, scheduler,
            Comet and output-path settings read below.
    """

    def __init__(self, model, dataloader, args):
        self.model = model
        self.args = args
        self.data = dataloader
        # Checkpoint-selection metric; must be a key returned by ``val``
        # ('loss' or 'acc').
        self.metric = args.metric

        if dataloader is not None:
            # Log every ``frq_log``-th batch, i.e. about ``args.frq_log``
            # times per epoch. Clamp to 1: with fewer batches than
            # ``args.frq_log`` the floor division is 0 and the modulo in
            # ``train_one_epoch`` would raise ZeroDivisionError.
            self.frq_log = max(1, len(dataloader['train']) // args.frq_log)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        model.to(self.device)

        if args.optimizer == 'sgd':
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        elif args.optimizer == 'adam':
            self.optimizer = optim.Adam(model.parameters(),
                                        lr=args.lr,
                                        betas=(args.beta1, 0.999),
                                        weight_decay=args.weight_decay)
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError('--optimizer should be one of {sgd, adam}')

        if args.scheduler == 'set':
            # LR multiplier grows by 10x every ``scheduler_factor`` epochs;
            # ``train`` records (lr, loss) pairs and plots them at the end.
            self.scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lambda epoch: 10**(epoch / args.scheduler_factor))
        elif args.scheduler == 'auto':
            # Reduce the LR when the training loss passed to ``step`` stops
            # decreasing.
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=args.scheduler_factor,
                patience=5,
                verbose=True,
                threshold=0.0001,
                threshold_mode='rel',
                cooldown=0,
                min_lr=0,
                eps=1e-08)

        self.experiment = Experiment(api_key=args.comet_key,
                                     project_name=args.comet_project,
                                     workspace=args.comet_workspace,
                                     auto_weight_logging=True,
                                     auto_metric_logging=False,
                                     auto_param_logging=False)

        self.experiment.set_name(args.name)
        self.experiment.log_parameters(vars(args))
        self.experiment.set_model_graph(str(self.model))

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over ``self.data['train']``.

        Per-batch loss/accuracy are logged to Comet at a global step of
        ``epoch * len(train_loader) + batch_idx``; every ``frq_log`` batches
        they are also logged under ``log_*`` and printed.

        Returns:
            dict with ``'loss'`` (mean per-sample cross-entropy) and ``'acc'``
            (percentage of correct top-1 predictions) over the epoch.
        """
        self.model.train()
        train_loader = self.data['train']
        train_loss = 0
        correct = 0

        comet_offset = epoch * len(train_loader)

        for batch_idx, (data, target) in tqdm(enumerate(train_loader),
                                              leave=True,
                                              total=len(train_loader)):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            # Summed (not averaged) so the epoch total can be divided by the
            # dataset size below.
            loss = F.cross_entropy(output, target, reduction='sum')
            loss.backward()
            self.optimizer.step()

            pred = output.argmax(dim=1, keepdim=True)
            acc = pred.eq(target.view_as(pred)).sum().item()
            train_loss += loss.item()
            correct += acc

            # Per-batch means for logging.
            loss = loss.item() / len(data)
            acc = 100. * acc / len(data)

            comet_step = comet_offset + batch_idx
            self.experiment.log_metric('batch_loss', loss, comet_step, epoch)
            self.experiment.log_metric('batch_acc', acc, comet_step, epoch)

            if (batch_idx + 1) % self.frq_log == 0:
                self.experiment.log_metric('log_loss', loss, comet_step, epoch)
                self.experiment.log_metric('log_acc', acc, comet_step, epoch)
                print('Epoch: {} [{}/{}]\tLoss: {:.6f}\tAcc: {:.2f}%'.format(
                    epoch + 1, (batch_idx + 1) * len(data),
                    len(train_loader.dataset), loss, acc))

        train_loss /= len(train_loader.dataset)
        acc = 100. * correct / len(train_loader.dataset)

        comet_step = comet_offset + len(train_loader) - 1
        self.experiment.log_metric('loss', train_loss, comet_step, epoch)
        self.experiment.log_metric('acc', acc, comet_step, epoch)

        print(
            'Epoch: {} [Done]\tLoss: {:.4f}\tAccuracy: {}/{} ({:.2f}%)'.format(
                epoch + 1, train_loss, correct, len(train_loader.dataset),
                acc))

        return {'loss': train_loss, 'acc': acc}

    def _is_improvement(self, value, best):
        """Return True if ``value`` beats ``best`` for ``self.metric``.

        'loss' is minimised; any other metric (e.g. 'acc') is maximised.
        ``best is None`` means nothing has been recorded yet.
        """
        if best is None:
            return True
        if self.metric == 'loss':
            return value < best
        return value > best

    def train(self):
        """Train for ``args.nepoch`` epochs, validating after each one.

        Weights are saved whenever the validation ``self.metric`` improves.
        In the ``finally`` clause the weights folder is uploaded to Comet
        and, for ``args.scheduler == 'set'``, the training-loss-vs-LR curve
        is plotted and logged. The ``finally`` clause runs even when
        training is interrupted.
        """
        self.log_cmd()
        best = None
        history = {'lr': [], 'train_loss': []}

        try:
            print(">> Training %s" % self.model.name)
            for epoch in range(self.args.nepoch):
                with self.experiment.train():
                    train_res = self.train_one_epoch(epoch)

                with self.experiment.validate():
                    print("\nvalidation...")
                    # Place validation points at the last training step of
                    # this epoch on Comet's x-axis.
                    comet_offset = (epoch + 1) * len(self.data['train']) - 1
                    res = self.val(self.data['val'], comet_offset, epoch)

                # Direction-aware comparison: the previous ``> best`` check
                # kept the *worst* model when the metric was 'loss'.
                if self._is_improvement(res[self.metric], best):
                    best = res[self.metric]
                    self.save_weights(epoch)

                if self.args.scheduler == 'set':
                    lr = self.optimizer.param_groups[0]['lr']
                    history['lr'].append(lr)
                    history['train_loss'].append(train_res['loss'])

                    self.scheduler.step(epoch + 1)
                    lr = self.optimizer.param_groups[0]['lr']
                    print('learning rate changed to: %.10f' % lr)

                elif self.args.scheduler == 'auto':
                    self.scheduler.step(train_res['loss'])
        finally:
            print(">> Training model %s. [Stopped]" % self.model.name)
            self.experiment.log_asset_folder(os.path.join(
                self.args.outf, self.args.name, 'weights'),
                                             step=None,
                                             log_file_name=False,
                                             recursive=False)
            if self.args.scheduler == 'set':
                plt.semilogx(history['lr'], history['train_loss'])
                plt.grid(True)
                self.experiment.log_figure(figure=plt)
                plt.show()

    def val(self, val_loader, comet_offset=-1, epoch=-1):
        """Evaluate on ``val_loader`` and log metrics plus a confusion matrix.

        Args:
            val_loader: loader yielding ``(data, target)`` batches.
            comet_offset: Comet step at which the metrics are logged.
            epoch: epoch index used for Comet and the confusion-matrix name.

        Returns:
            dict with ``'loss'`` (mean per-sample cross-entropy) and ``'acc'``
            (percentage of correct top-1 predictions).
        """
        self.model.eval()
        test_loss = 0
        correct = 0

        labels = list(range(self.args.nclass))
        cm = np.zeros((len(labels), len(labels)))

        with torch.no_grad():
            for data, target in tqdm(val_loader,
                                     leave=True,
                                     total=len(val_loader)):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += F.cross_entropy(output, target,
                                             reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

                pred = pred.view_as(target).data.cpu().numpy()
                target = target.data.cpu().numpy()
                # Fixed ``labels`` keeps the matrix nclass x nclass even when
                # a batch lacks some classes.
                cm += confusion_matrix(target, pred, labels=labels)

        test_loss /= len(val_loader.dataset)
        accuracy = 100. * correct / len(val_loader.dataset)

        print('Evaluation: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.
              format(test_loss, correct, len(val_loader.dataset), accuracy))

        res = {'loss': test_loss, 'acc': accuracy}

        self.experiment.log_metrics(res, step=comet_offset, epoch=epoch)
        self.experiment.log_confusion_matrix(
            matrix=cm,
            labels=[ClassDict.getName(x) for x in labels],
            title='confusion matrix after epoch %03d' % epoch,
            file_name="confusion_matrix_%03d.json" % epoch)

        return res

    def test(self):
        """Load saved weights, evaluate on ``self.data['test']``, return metrics."""
        self.load_weights()
        with self.experiment.test():
            print('\ntesting....')
            res = self.val(self.data['test'])
        return res

    def log_cmd(self):
        """Log a shell command that reproduces ``self.args`` to Comet as text.

        ``None``, empty-string and ``False`` options are omitted. ``True``
        booleans become bare flags.
        """
        d = vars(self.args)
        cmd = '!python main.py \\\n'
        tab = '    '

        for k, v in d.items():
            if v is None or v == '' or (isinstance(v, bool) and v is False):
                continue

            if isinstance(v, bool):
                arg = '--{} \\\n'.format(k)
            else:
                arg = '--{} {} \\\n'.format(k, v)

            cmd = cmd + tab + arg

        self.experiment.log_text(cmd)

    def save_weights(self, epoch: int):
        """Write ``{'epoch', 'state_dict'}`` to ``<outf>/<name>/weights/model.pth``."""
        weight_dir = os.path.join(self.args.outf, self.args.name, 'weights')
        os.makedirs(weight_dir, exist_ok=True)

        torch.save({
            'epoch': epoch,
            'state_dict': self.model.state_dict()
        }, os.path.join(weight_dir, 'model.pth'))

    def load_weights(self):
        """Load a checkpoint written by ``save_weights`` into ``self.model``.

        Uses ``args.weights_path`` when set, otherwise the default
        ``<outf>/<name>/weights/model.pth``.
        """
        path_g = self.args.weights_path

        if path_g is None:
            weight_dir = os.path.join(self.args.outf, self.args.name,
                                      'weights')
            path_g = os.path.join(weight_dir, 'model.pth')

        print('>> Loading weights...')
        weights_g = torch.load(path_g, map_location=self.device)['state_dict']
        self.model.load_state_dict(weights_g)
        print('   Done.')

    def predict(self, x):
        """Return softmax class probabilities for one un-batched NumPy input.

        ``x`` is divided by 2**15 (presumably scaling int16 PCM samples to
        [-1, 1); confirm with callers). It is then passed through
        ``self.transform`` when one has been attached, given a leading batch
        dimension of 1, and run through the model.

        Returns:
            NumPy array of softmax probabilities with a leading batch axis.
        """
        x = x / 2**15
        self.model.eval()
        # ``transform`` is never assigned in this class; use the identity
        # when callers have not attached one instead of raising.
        transform = getattr(self, 'transform', None)
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            if transform is not None:
                x = transform(x)
            # The model lives on ``self.device``; the input must match.
            x = x.unsqueeze(0).to(self.device)
            x = self.model(x)
            x = F.softmax(x, dim=1)
        # ``.numpy()`` only works on host tensors.
        return x.cpu().numpy()
Ejemplo n.º 14
0
def main():
    """Supervised pre-training entry point.

    Builds the train/validation loaders for the dataset named by
    ``opt.dataset``, trains for ``opt.epochs`` epochs while logging to
    TensorBoard and Comet, saves a checkpoint every ``opt.save_freq`` epochs
    and finally writes ``<opt.model>_last.pth`` to ``opt.save_folder``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of
            miniImageNet, tieredImageNet, CIFAR-FS or FC100.
    """
    opt = parse_option()

    # dataloader
    train_partition = "trainval" if opt.use_trainval else "train"
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == "tieredImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition="train_phase_val", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # NOTE(review): the meta loaders below are not used by this routine.
        meta_testloader = DataLoader(
            MetaTieredImageNet(
                args=opt,
                partition="test",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaTieredImageNet(
                args=opt,
                partition="val",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset == "CIFAR-FS" or opt.dataset == "FC100":
        train_trans, test_trans = transforms_options["D"]

        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        # NOTE(review): "validation" here runs on the *train* partition
        # (with test-time transforms); confirm this is intended.
        val_loader = DataLoader(
            CIFAR100(args=opt, partition="train", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # NOTE(review): the meta loaders below are not used by this routine.
        meta_testloader = DataLoader(
            MetaCIFAR100(
                args=opt,
                partition="test",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaCIFAR100(
                args=opt,
                partition="val",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == "CIFAR-FS":
            n_cls = 64
        else:  # FC100 (the outer branch admits only these two names)
            n_cls = 60
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, opt.drop_rate, opt.dropblock)

    # optimizer
    if opt.adam:
        optimizer = optim.Adam(
            model.parameters(), lr=opt.learning_rate, weight_decay=0.0005
        )
    else:
        optimizer = optim.SGD(
            model.parameters(),
            lr=opt.learning_rate,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay,
        )

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # The model is wrapped in DataParallel only when CUDA is available, so
    # inspect the wrapper instead of trusting ``opt.n_gpu`` when saving.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
    comet_logger = Experiment(
        # ``.get`` so a missing env var does not crash runs where Comet is
        # disabled via ``opt.logcomet``.
        api_key=os.environ.get("COMET_API_KEY"),
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** opt.cosine_factor)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1
        )

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        # The LR is updated at the start of each epoch, matching the
        # epoch-indexed ``adjust_learning_rate`` path.
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        with comet_logger.train():
            train_acc, train_loss = train(
                epoch, train_loader, model, criterion, optimizer, opt
            )
            comet_logger.log_metrics(
                {"acc": train_acc.cpu(), "loss_epoch": train_loss}, epoch=epoch
            )
        time2 = time.time()
        print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))

        logger.log_value("train_acc", train_acc, epoch)
        logger.log_value("train_loss", train_loss, epoch)

        with comet_logger.validate():
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt
            )
            comet_logger.log_metrics(
                {"acc": test_acc.cpu(), "acc_top5": test_acc_top5.cpu(), "loss": test_loss},
                epoch=epoch,
            )

        logger.log_value("test_acc", test_acc, epoch)
        logger.log_value("test_acc_top5", test_acc_top5, epoch)
        logger.log_value("test_loss", test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print("==> Saving...")
            state = {
                "epoch": epoch,
                "model": base_model.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch)
            )
            torch.save(state, save_file)

    # save the last model
    state = {
        "opt": opt,
        "model": base_model.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
Ejemplo n.º 15
0
                                        cd_corrects,
                                        cd_train_report)

            # log the batch mean metrics
            mean_train_metrics = get_mean_metrics(train_metrics)
            comet.log_metrics(mean_train_metrics)

            # clear batch variables from memory
            del batch_img1, batch_img2, labels

        print("EPOCH TRAIN METRICS", mean_train_metrics)

    """
    Begin Validation
    """
    with comet.validate():
        model.eval()

        first_batch = True
        for batch_img1, batch_img2, labels in val_loader:
            # Set variables for training
            batch_img1 = autograd.Variable(batch_img1).to(dev)
            batch_img2 = autograd.Variable(batch_img2).to(dev)
            labels = autograd.Variable(labels).long().to(dev)

            # Get predictions and calculate loss
            cd_preds = model(batch_img1, batch_img2)
            cd_loss = criterion(cd_preds, labels)
            _, cd_preds = torch.max(cd_preds, 1)

            # If this is the first batch, comet log the loss to gauge results
Ejemplo n.º 16
0
                    loss, accuracy = model.train_on_batch(
                        train_text[train_step], train_labels[train_step])
                    train_loss.append(loss)
                    train_accuracy.append(accuracy)

                    experiment.log_metric('loss',
                                          np.mean(train_loss),
                                          step=global_step)
                    experiment.log_metric('accuracy',
                                          np.mean(train_accuracy),
                                          step=global_step)

                    # Every evaluate_steps evaluate model on validation set
                    if (train_step + 1) % evaluate_steps == 0 or (
                            train_step + 1) == train_steps:
                        with experiment.validate():
                            for val_step in range(val_steps):

                                # Perform evaluation step on batch and record metrics
                                loss, accuracy = model.test_on_batch(
                                    val_text[val_step], val_labels[val_step])
                                val_loss.append(loss)
                                val_accuracy.append(accuracy)

                                experiment.log_metric('loss',
                                                      np.mean(val_loss),
                                                      step=global_step)
                                experiment.log_metric('accuracy',
                                                      np.mean(val_accuracy),
                                                      step=global_step)
Ejemplo n.º 17
0
def main(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args.gpu_idx)
    # if len(args.gpu_idx):
    #     device = torch.device("cuda" )
    #     print(device)
    # else:
    #     device = torch.device("cpu")

    # dataloader
    input_channel = 1
    if args.dataset == 'fashionmnist':
        args.num_classes = 2

        # args.train_path = '/storage/fei/data/'
        # args.val_path = '/storage/fei/data/'
        # transform = transforms.Compose([transforms.ToTensor(),
        #                                 transforms.Normalize((0.1307,), (0.3081,))])
        #
        # train_set = torchvision.datasets.FashionMNIST(
        #     root=args.train_path,
        #     train=True,
        #     transform=transform
        # )
        # val_set = torchvision.datasets.FashionMNIST(
        #     root=args.val_path,
        #     train=False,
        #     transform=transform
        # )

        from keras.datasets import fashion_mnist
        (trainX, trainy), (testX, testy) = fashion_mnist.load_data()
        # train_set = fashionMNIST(trainX, trainy, real=[5, 7, 9], fake=[0, 1, 2, 3])
        # val_set = fashionMNIST(testX, testy, real=[5, 7, 9], fake=[0, 1, 2, 3, 4, 6, 8])

        real = [3]
        fake_val = [0, 2]
        fake_test = [4, 6]

        train_set = fashionMNIST(trainX, trainy, real=real, fake=fake_val)
        val_set = fashionMNIST(testX, testy, real=real, fake=fake_val)
        test_set = fashionMNIST(testX, testy, real=real, fake=fake_test)
    else:
        raise ValueError(
            'Dataset should be: voxceleb1, imagenet, fashionmnist!')
    #
    train_dataloader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4,
                                  drop_last=True)
    val_dataloader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=4)
    test_dataloader = DataLoader(test_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)

    model = Model_base(args).cuda()

    experiment = Experiment(API_KEY, project_name='OC-Softmax')
    experiment.log_parameters(vars(args))
    experiment.set_name(args.model_dir)
    numparams = 0
    for f in model.parameters():
        if f.requires_grad:
            numparams += f.numel()
    experiment.log_parameter('Parameters', numparams)
    print('Total number of parameters: {}'.format(numparams))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model.backbone = nn.DataParallel(model.backbone)

    model = model.cuda()

    # Optimizer
    # optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    # optimizer = optim.SGD([{'params': model.backbone.parameters()},
    #                        {'params': model.softmax_layer.parameters()}],
    #                       lr=args.lr, momentum=0.9, nesterov=False)
    # scheduler = StepLR(optimizer, step_size=30, gamma=0.6)

    # Save config
    model_path = os.path.join(args.exp_dir, args.model_dir)
    log_path = os.path.join(model_path, 'logs')
    if os.path.exists(log_path):
        res = input("Experiment {} already exists, continue? (y/n)".format(
            args.model_dir))
        print(res)
        if res == 'n':
            sys.exit()
    os.makedirs(log_path, exist_ok=True)
    conf_path = os.path.join(log_path, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(vars(args), outfile)

    log_file = '{}/stats.csv'.format(log_path)
    log_content = [
        'Epoch', 'tr_acc', 'val_acc', 'test_acc', 'val_eer', 'test_eer',
        'tr_loss', 'val_loss', 'test_loss'
    ]

    if not os.path.exists(log_file):
        with open(log_file, 'w') as f:
            writer = csv.writer(f)
            writer.writerow(log_content)

    # Train model
    tr_step = 0
    val_step = 0
    new_lr = args.lr
    halving = False
    best_val_loss = float("-inf")
    val_no_impv = 0

    # training
    iteration = 0
    tr_step = 0
    for epoch in range(args.epochs):
        metric_dic = {}
        for m in log_content[1:]:
            metric_dic[m] = []
        current_lr = adjust_learning_rate(optimizer, tr_step, args.lr)
        # print('Epoch:', epoch,'LR:', optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr'])
        print('Epoch: {}, learning rate: {}'.format(epoch + 1, current_lr))
        # train_utils.val_step(spk_classifier, embedding, val_dataloader,  iteration, val_log_path)

        # Training
        model.train()
        for data in tqdm(train_dataloader,
                         desc='{} Training'.format(
                             args.model_dir)):  # mini-batch
            # one batch of training data
            # input_feature, target = data['input_feature'].to(device), data['target'].to(device)
            input_feature, target = data[0].cuda(), data[1].cuda()

            # gradient accumulates
            optimizer.zero_grad()

            # embedding
            # embeddings = model.backbone(input_feature)
            output, loss = model(input_feature, target)
            metric_dic['tr_loss'].append(loss.detach().cpu())

            # if args.center > 0:
            #     l_c = 0
            #     for i in range(model.embeddings.shape[0]):
            #         l_c = l_c + 0.5 * (model.embeddings[i] - W[:, target[i]]).pow(2).sum()
            #     l_c = l_c / model.embeddings.shape[0]
            #     loss = loss + args.center * l_c
            #     metric_dic['tr_center_loss'].append(l_c.detch().cpu())
            #
            # if args.w_ortho > 0:
            #     W = F.normalize(model.softmax_layer.W, p=2, dim=0)
            #     l_w_reg = (W.T @ W - torch.eye(W.shape[1]).cuda()).norm(p=2)
            #     loss = loss + args.w_ortho * l_w_reg
            #     metric_dic['tr_w_reg'].append(l_w_reg.detach().cpu())

            train_acc = utils.accuracy(output, target)[0]  # Top-1 acc
            metric_dic['tr_acc'].append(train_acc.cpu())

            loss.backward()
            #             torch.nn.utils.clip_grad_norm_(embedding.parameters(), 1.0)
            #             torch.nn.utils.clip_grad_norm_(spk_classifier.parameters(), 1.0)
            optimizer.step()

            if iteration % 100 == 0:
                print('Train loss: {:.2f}, Acc: {:.2f}%'.format(
                    loss.item(), train_acc))

            iteration += 1
        tr_step += 1

        # res_dic['tr_loss']['acc'] += l.tolist()

        # Validation
        if val_dataloader is not None:
            model.eval()
            outputs = []
            targets = []
            with torch.no_grad():
                for data in tqdm(val_dataloader,
                                 desc='Validation'):  # mini-batch
                    # input_feature, target = data['input_feature'].to(device), data['target'].to(device)
                    input_feature, target = data[0].cuda(), data[1].cuda()

                    output, loss = model(input_feature, target)

                    # val_acc = utils.accuracy(output, target)[0] # Top-1 acc
                    # metric_dic['val_acc'].append(val_acc.cpu())
                    metric_dic['val_loss'].append(loss.cpu())
                    outputs.append(output)
                    targets.append(target)
            metric_dic['val_acc'] = utils.accuracy(
                torch.cat(outputs).cpu(),
                torch.cat(targets).cpu())[0]
            metric_dic['val_acc'] = metric_dic['val_acc'].item()

            eer1, _ = utils.compute_eer(
                torch.cat(outputs).cpu()[:, 0],
                torch.cat(targets).cpu())
            eer2, _ = utils.compute_eer(-torch.cat(outputs).cpu()[:, 0],
                                        torch.cat(targets).cpu())
            metric_dic['val_eer'] = min(eer1, eer2)

        # Test
        if test_dataloader is not None:
            model.eval()
            outputs = []
            targets = []
            with torch.no_grad():
                for data in tqdm(test_dataloader,
                                 desc='Validation'):  # mini-batch
                    # input_feature, target = data['input_feature'].to(device), data['target'].to(device)
                    input_feature, target = data[0].cuda(), data[1].cuda()

                    output, loss = model(input_feature, target)

                    # val_acc = utils.accuracy(output, target)[0] # Top-1 acc
                    # metric_dic['val_acc'].append(val_acc.cpu())
                    metric_dic['test_loss'].append(loss.cpu())
                    outputs.append(output)
                    targets.append(target)
            metric_dic['test_acc'] = utils.accuracy(
                torch.cat(outputs).cpu(),
                torch.cat(targets).cpu())[0]
            metric_dic['test_acc'] = metric_dic['test_acc'].item()

            eer1, _ = utils.compute_eer(
                torch.cat(outputs).cpu()[:, 0],
                torch.cat(targets).cpu())
            eer2, _ = utils.compute_eer(-torch.cat(outputs).cpu()[:, 0],
                                        torch.cat(targets).cpu())
            metric_dic['test_eer'] = min(eer1, eer2)

        for metric in metric_dic.keys():
            if isinstance(metric_dic[metric], list):
                metric_dic[metric] = np.mean(metric_dic[metric])
            if metric[:3] == 'tr_':
                with experiment.train():
                    experiment.log_metric(metric[3:],
                                          metric_dic[metric],
                                          step=tr_step)
            if metric[:4] == 'val_':
                with experiment.validate():
                    experiment.log_metric(metric[4:],
                                          metric_dic[metric],
                                          step=tr_step)

        pprint(metric_dic)

        # Write logs
        with open(log_file, 'a') as f:
            writer = csv.writer(f)
            write_content = [tr_step
                             ] + [metric_dic[m] for m in metric_dic.keys()]
            writer.writerow(write_content)

        Model_base.save_if_best(save_dir=model_path,
                                model=model,
                                optimizer=optimizer,
                                epoch=tr_step,
                                tr_metric=metric_dic['tr_acc'],
                                val_metric=metric_dic['val_eer'],
                                metric_name='eer',
                                save_every=10)