Example #1
 def on_epoch_end(self, epoch, logs=None):
     # Save a checkpoint every `period` epochs; with save_best_only, only
     # when the monitored metric improves.
     logs = logs or {}
     self.epochs_since_last_save += 1
     if self.epochs_since_last_save >= self.period:
         self.epochs_since_last_save = 0
         if self.save_best_only:
             current = logs.get(self.monitor)
             if current is None:
                 logging.warning('Can save best model only with %s available, '
                                 'skipping.', self.monitor)
             else:
                 if self.monitor_op(current, self.best):
                     if self.verbose > 0:
                         print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                               ' saving model to %s'
                               % (epoch, self.monitor, self.best,
                                  current, self.filepath))
                     self.best = current
                     save_weights(self.model.model, self.optimizer, self.filepath)
                 else:
                     if self.verbose > 0:
                         print('Epoch %05d: %s did not improve' %
                               (epoch, self.monitor))
         else:
             if self.verbose > 0:
                 print('Epoch %05d: saving model to %s' % (epoch, self.filepath))
             save_weights(self.model.model, self.optimizer, self.filepath)
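
Note: the save_weights helper called in these snippets is project-specific code that is not shown on this page. Purely as a sketch, a checkpoint saver with the (model, optimizer, filepath) call shape used above might look like this in PyTorch (an assumption, not the original implementation):

import torch

def save_weights(model, optimizer, filepath):
    # Persist both the model parameters and the optimizer state so that
    # training can later be resumed from this checkpoint.
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict()},
               filepath)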
Example #2
def training_loop(model,
                  loss_function,
                  metrics,
                  optimizer,
                  meta_data,
                  config,
                  save_path,
                  train,
                  valid,
                  custom_callbacks=[],
                  checkpoint_monitor="val_acc",
                  use_tb=False,
                  reload=True,
                  n_epochs=100,
                  save_freq=1,
                  save_history_every_k_examples=-1,
                  device=None):
    callbacks = list(custom_callbacks)

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if reload:
        H, epoch_start = _reload(model, optimizer, save_path, callbacks)
    else:
        save_weights(model, optimizer,
                     os.path.join(save_path, "init_weights.pt"))

        history_csv_path, history_pkl_path = os.path.join(
            save_path, "history.csv"), os.path.join(save_path, "history.pkl")
        logger.info("Removing {} and {}".format(history_pkl_path,
                                                history_csv_path))
        os.system("rm " + history_pkl_path)
        os.system("rm " + history_csv_path)
        H, epoch_start = {}, 0

    callbacks += _construct_default_callbacks(model, optimizer, H, save_path,
                                              checkpoint_monitor, save_freq,
                                              custom_callbacks, use_tb,
                                              save_history_every_k_examples)

    # Configure callbacks
    for clbk in callbacks:
        clbk.set_save_path(save_path)
        clbk.set_model(model, ignore=False)  # TODO: Remove this trick
        clbk.set_optimizer(optimizer)
        clbk.set_meta_data(meta_data)
        clbk.set_config(config)

    _training_loop(model,
                   valid,
                   train,
                   optimizer,
                   loss_function,
                   epoch_start,
                   n_epochs,
                   callbacks,
                   metrics=metrics,
                   device=device)
Example #3
    def run_epoch(set_key: str,
                  i_ep: int,
                  all_ind: np.ndarray,
                  train: bool,
                  batches_per_epoch: int,
                  lr_step: float = 0.0):
        torch.cuda.empty_cache()
        if train:
            model.train()
        else:
            model.eval()

        metrics = {}
        counts = Counter()

        for batch_i in tqdm(range(batches_per_epoch),
                            desc=f'{save_path} {set_key} epoch {i_ep + 1}'):
            if lr_step > 0:
                set_lr(get_lr() + lr_step)
            try:
                batch_metrics = run_batch(all_ind, train)
                for k, v in batch_metrics.items():
                    if k not in metrics:
                        metrics[k] = 0
                    metrics[k] += v
                    counts[k] += 1
            except AssertionError as e:
                # batch skipped because of zero loss
                logger.debug(f"Exception while running batch: {str(e)}")
            except Exception as e:
                logger.warning(f"Exception while running batch: {str(e)}")
                raise e

        metrics = dict((k, v / counts[k]) for k, v in metrics.items())
        str_metrics = ', '.join("{:s}={:.4f}".format(k, v)
                                for k, v in metrics.items())
        logger.info(f'{set_key} epoch {i_ep + 1}: {str_metrics}')

        if train:
            save_weights(model_path,
                         model,
                         optimizer,
                         epoch=i_ep,
                         lr=get_lr(),
                         no_progress=no_progress)

            if save_each_epoch:
                model_epoch_path = os.path.join(
                    checkpoints_path,
                    f'model_{(i_ep + 1) * train_samples_per_epoch}.pt')
                save_weights(model_epoch_path,
                             model,
                             optimizer,
                             epoch=i_ep,
                             lr=get_lr())
        return metrics
Example #4
def train(model,
          optimizer,
          scheduler,
          dataset,
          cfg,
          val_dataset=None,
          vis=True):

    training_loss_list = []
    test_loss_list = []

    for epoch in range(cfg['max_epoch']):
        true_nums = 0
        for index, (inputs, label) in enumerate(dataset):
            inputs = img_preprocess(inputs)
            outputs = model.forward(inputs)
            loss, reg_loss = model.compute_loss(label)
            grads = model.backward()
            optimizer.step(grads)

            true_num, precision = cal_precision(outputs, label)
            true_nums += true_num
            logging.info(
                "[%d/%d] train loss: %.2f, reg loss: %.2f, total loss: %.4f, precision %.4f || lr: %.6f"
                % (epoch, index, loss, reg_loss,
                   (loss + reg_loss), precision, optimizer.lr))
        scheduler.step()
        params_path = save_weights(model.params, cfg['workspace'], model.name,
                                   epoch)
        logging.info("save model at: %s, training precision %.4f" %
                     (params_path, true_nums / dataset.total))
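        # Note: despite its name, training_loss_list below records the epoch's
        # training precision (true_nums / dataset.total) rather than a loss value.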
        training_loss_list.append(true_nums / dataset.total)

        if val_dataset is not None:
            loss = val(model, model.name, params_path, val_dataset)
            test_loss_list.append(loss)

    if vis:
        draw_loss_graph(cfg['workspace'] + "/loss.png", training_loss_list,
                        test_loss_list)
Example #5
 def save_weights_fnc(epoch, logs):
     if epoch % save_freq == 0:
         logger.info("Saving model from epoch " + str(epoch))
         save_weights(model, optimizer,
                      os.path.join(save_path, "model_last_epoch.pt"))
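
The (epoch, logs) signature of save_weights_fnc matches a Keras-style end-of-epoch hook. As an illustration only (keras.callbacks.LambdaCallback is an assumption here, not something shown in the source project), such a closure could be registered like this:

from tensorflow import keras

# Hypothetical wiring: call save_weights_fnc at the end of every epoch.
checkpoint_cb = keras.callbacks.LambdaCallback(on_epoch_end=save_weights_fnc)
# model.fit(x, y, epochs=10, callbacks=[checkpoint_cb])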
Example #6
def training_loop(model,
                  loss_function,
                  metrics,
                  optimizer,
                  scheduler,
                  meta_data,
                  config,
                  save_path,
                  steps_per_epoch,
                  n_epochs,
                  hyper_optim=None,
                  train=None,
                  valid=None,
                  test=None,
                  validation_per_epoch=1,
                  data_loader=None,
                  meta_data_loader=None,
                  test_steps=None,
                  validation_steps=None,
                  use_gpu=False,
                  device_numbers=[0],
                  pretrained=False,
                  BIRADS_pretrained_weights_path=None,
                  pretrained_weight_paths=None,
                  custom_callbacks=[],
                  checkpoint_monitor="val_acc",
                  use_tb=False,
                  reload=True,
                  save_freq=1,
                  save_history_every_k_examples=-1,
                  fb_method=False,
                  target_indice=None,
                  grad_norm_penalty=None,
                  penalty_on_columns=False,
                  task_name=None,
                  grad_norm_penalty_var='vanilla',
                  num_of_view_to_change=1,
                  suppress_nan_labels_in_loss=True):

    callbacks = [BaseLogger()] + list(custom_callbacks)

    if pretrained and pretrained_weight_paths is not None:
        _load_pretrained_branch(model, pretrained_weight_paths)
    elif BIRADS_pretrained_weights_path is not None:
        _load_BIRADS_pretrained(model, BIRADS_pretrained_weights_path)

    if reload:
        H, epoch_start = _reload(model, optimizer, save_path, custom_callbacks)
    else:
        save_weights(model, optimizer,
                     os.path.join(save_path, "init_weights.pt"))

        history_csv_path, history_pkl_path = os.path.join(
            save_path, "history.csv"), os.path.join(save_path, "history.pkl")
        logger.info("Removing {} and {}".format(history_pkl_path,
                                                history_csv_path))
        os.system("rm " + history_pkl_path)
        os.system("rm " + history_csv_path)
        H, epoch_start = {}, 0

    H_batch = {}

    if train is not None:
        default_callbacks = _construct_default_callbacks(
            model, optimizer, H, H_batch, save_path, checkpoint_monitor,
            save_freq, custom_callbacks, use_tb, save_history_every_k_examples)

    else:
        # If train is None we only construct evaluation callbacks. Note that
        # training_loop is not designed for pure evaluation; use evaluation_loop
        # when setting train to None.
        default_callbacks = _construct_default_eval_callbacks(
            H, H_batch, save_path, save_history_every_k_examples)

    callbacks = callbacks + default_callbacks

    # Configure callbacks
    for clbk in callbacks:
        clbk.set_save_path(save_path)
        clbk.set_model(model, ignore=False)  # TODO: Remove this trick
        clbk.set_optimizer(optimizer)
        clbk.set_meta_data(meta_data)
        clbk.set_config(config)
        clbk.set_dataloader(None)

    is_multi_gpu = False
    base_device = torch.device("cpu")  # default, overridden below when a GPU is used

    if use_gpu and torch.cuda.is_available():
        is_multi_gpu = len(device_numbers) > 1
        base_device = torch.device("cuda:{}".format(device_numbers[0]))

        if is_multi_gpu:
            model = torch.nn.DataParallel(model, device_ids=device_numbers)

        logger.info("Sending model to {}".format(base_device))
        model.to(base_device)
        optimizer.load_state_dict(
            optimizer.state_dict()
        )  # Hack to use right device for optimizer, according to https://github.com/pytorch/pytorch/issues/8741

    _loop(model,
          test,
          valid,
          train,
          optimizer,
          scheduler,
          loss_function,
          initial_epoch=epoch_start,
          epochs=n_epochs,
          callbacks=callbacks,
          metrics=metrics,
          device=base_device,
          steps_per_epoch=steps_per_epoch,
          suppress_nan_labels_in_loss=suppress_nan_labels_in_loss)
Example #7
def training_loop(model,
                  loss_function,
                  metrics,
                  optimizer,
                  meta_data,
                  config,
                  save_path,
                  train,
                  valid,
                  steps_per_epoch,
                  custom_callbacks=[],
                  checkpoint_monitor="val_acc",
                  use_tb=False,
                  reload=True,
                  n_epochs=100,
                  save_freq=1,
                  save_history_every_k_examples=-1):
    callbacks = list(custom_callbacks)

    if reload:
        H, epoch_start = _reload(model, optimizer, save_path, callbacks)
    else:
        save_weights(model, optimizer,
                     os.path.join(save_path, "init_weights.pt"))

        history_csv_path, history_pkl_path = os.path.join(
            save_path, "history.csv"), os.path.join(save_path, "history.pkl")
        logger.info("Removing {} and {}".format(history_pkl_path,
                                                history_csv_path))
        os.system("rm " + history_pkl_path)
        os.system("rm " + history_csv_path)
        H, epoch_start = {}, 0

    callbacks += _construct_default_callbacks(model, optimizer, H, save_path,
                                              checkpoint_monitor, save_freq,
                                              custom_callbacks, use_tb,
                                              save_history_every_k_examples)

    # Configure callbacks
    for clbk in callbacks:
        clbk.set_save_path(save_path)
        clbk.set_model(model, ignore=False)  # TODO: Remove this trick
        clbk.set_optimizer(optimizer)
        clbk.set_meta_data(meta_data)
        clbk.set_config(config)

    model = Model(model=model,
                  optimizer=optimizer,
                  loss_function=loss_function,
                  metrics=metrics)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")  # the outer check already guarantees CUDA is available
        logger.info("Sending model to {}".format(device))
        model.to(device)

    _ = model.fit_generator(
        train,
        initial_epoch=epoch_start,
        steps_per_epoch=steps_per_epoch,
        epochs=n_epochs - 1,  # Weird convention
        verbose=1,
        valid_generator=valid,
        callbacks=callbacks)
Example #8
def main(args):
    """
    Main function to execute the training.
    Performs training, validation after each epoch, and testing once all epochs are complete.
    :param args: input command line arguments which will set the learning rate, number of epochs, data root etc.
    :return: None
    """

    sess_path = utils.create_session(args.log_dir, args.session_id)  # Create a session path based on the session id.
    G = tf.Graph()
    with G.as_default():
        # Create image and density map placeholder
        image_place_holder = tf.placeholder(tf.float32, shape=[1, None, None, 1])
        d_map_place_holder = tf.placeholder(tf.float32, shape=[1, None, None, 1])

        # Build all nodes of the network
        d_map_est = mccnn.build(image_place_holder)

        # Define the loss function.
        euc_loss = L.loss(d_map_est, d_map_place_holder)

        # Define the optimization algorithm
        optimizer = tf.train.GradientDescentOptimizer(args.learning_rate)

        # Training node.
        train_op = optimizer.minimize(euc_loss)

        # Initialize all the variables.
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        # For summary
        summary = tf.summary.merge_all()

        with tf.Session(graph=G) as sess:
            writer = tf.summary.FileWriter(os.path.join(sess_path,'training_logging'))
            writer.add_graph(sess.graph)
            sess.run(init)

            #if args.retrain:
            #    utils.load_weights(G, args.base_model_path)


            # Start the epochs
            for eph in range(args.num_epochs):

                start_train_time = time.time()

                # Get the list of train images.
                train_images_list, train_gts_list = utils.get_data_list(args.data_root, mode='train')
                total_train_loss = 0

                # Loop through all the training images
                for img_idx in range(len(train_images_list)):

                    # Load the image and ground truth
                    train_image = np.asarray(mpimg.imread(train_images_list[img_idx]), dtype=np.float32)
                    train_d_map = np.asarray(sio.loadmat(train_gts_list[img_idx])['d_map'], dtype=np.float32)

                    # Reshape the tensor before feeding it to the network
                    train_image_r = utils.reshape_tensor(train_image)
                    train_d_map_r = utils.reshape_tensor(train_d_map)

                    # Prepare feed_dict
                    feed_dict_data = {
                        image_place_holder: train_image_r,
                        d_map_place_holder: train_d_map_r,
                    }

                    # Compute the loss for one image.
                    _, loss_per_image = sess.run([train_op, euc_loss], feed_dict=feed_dict_data)

                    # Accumulate the loss over all the training images.
                    total_train_loss = total_train_loss + loss_per_image

                end_train_time = time.time()
                train_duration = end_train_time - start_train_time

                # Compute the average training loss
                avg_train_loss = total_train_loss / len(train_images_list)

                # Then we print the results for this epoch:
                print("Epoch {} of {} took {:.3f}s".format(eph + 1, args.num_epochs, train_duration))
                print("  Training loss:\t\t{:.6f}".format(avg_train_loss))


                print ('Validating the model...')

                total_val_loss = 0

                # Get the list of images and the ground truth
                val_image_list, val_gt_list = utils.get_data_list(args.data_root, mode='valid')

                valid_start_time = time.time()

                # Loop through all the images.
                for img_idx in range(len(val_image_list)):

                    # Read the image and the ground truth
                    val_image = np.asarray(mpimg.imread(val_image_list[img_idx]), dtype=np.float32)
                    val_d_map = np.asarray(sio.loadmat(val_gt_list[img_idx])['d_map'], dtype=np.float32)

                    # Reshape the tensor for feeding it to the network
                    val_image_r = utils.reshape_tensor(val_image)
                    val_d_map_r = utils.reshape_tensor(val_d_map)

                    # Prepare the feed_dict
                    feed_dict_data = {
                        image_place_holder: val_image_r,
                        d_map_place_holder: val_d_map_r,
                    }

                    # Compute the loss per image
                    loss_per_image = sess.run(euc_loss, feed_dict=feed_dict_data)

                    # Accumulate the validation loss across all the images.
                    total_val_loss = total_val_loss + loss_per_image

                valid_end_time = time.time()
                val_duration = valid_end_time - valid_start_time

                # Compute the average validation loss.
                avg_val_loss = total_val_loss / len(val_image_list)

                print("  Validation loss:\t\t{:.6f}".format(avg_val_loss))
                print ("Validation over {} images took {:.3f}s".format(len(val_image_list), val_duration))

                # Save the weights as well as the summary
                utils.save_weights(G, os.path.join(sess_path, "weights.%s" % (eph+1)))
                summary_str = sess.run(summary, feed_dict=feed_dict_data)
                writer.add_summary(summary_str, eph)


            print ('Testing the model with test data.....')

            # Get the image list
            test_image_list, test_gt_list = utils.get_data_list(args.data_root, mode='test')
            abs_err = 0

            # Loop through all the images.
            for img_idx in range(len(test_image_list)):

                # Read the images and the ground truth
                test_image = np.asarray(mpimg.imread(test_image_list[img_idx]), dtype=np.float32)
                test_d_map = np.asarray(sio.loadmat(test_gt_list[img_idx])['d_map'], dtype=np.float32)                

                # Reshape the input image for feeding it to the network.
                test_image = utils.reshape_tensor(test_image)
                feed_dict_data = {image_place_holder: test_image}

                # Make prediction.
                pred = sess.run(d_map_est, feed_dict=feed_dict_data)                

                # Compute mean absolute error.
                abs_err += utils.compute_abs_err(pred, test_d_map)

            # Average across all the images.
            avg_mae = abs_err / len(test_image_list)
            print ("Mean Absolute Error over the Test Set: %s" %(avg_mae))
            print ('Finished.')
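
utils.save_weights in the TensorFlow example above is likewise project code that is not shown here. A rough TF1-style sketch of a saver with the (graph, path) call shape used above (an assumption, not the project's actual helper):

import tensorflow as tf

def save_weights(graph, path):
    # Save the trainable variables of `graph` through the currently active
    # session (the training session in the example above).
    with graph.as_default():
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.save(tf.get_default_session(), path)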
Example #9
def train_megan(save_path: str,
                featurizer_key: str,
                learning_rate: float = 0.0001,
                train_samples_per_epoch: int = -1,
                valid_samples_per_epoch: int = -1,
                batch_size: int = 4,
                gen_lr_factor: float = 0.1,
                gen_lr_patience: int = 4,
                big_lr_epochs: int = -1,
                early_stopping: int = 16,
                start_epoch: int = 0,
                megan_warmup_epochs: int = 1,
                save_each_epoch: bool = False,
                max_n_epochs: int = 1000):
    """
    Train MEGAN model
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    checkpoints_path = os.path.join(save_path, 'checkpoints')
    if save_each_epoch and not os.path.exists(checkpoints_path):
        os.makedirs(checkpoints_path)

    log_current_config()
    conf_path = os.path.join(save_path, 'config.gin')
    save_current_config(conf_path)

    model_path = os.path.join(save_path, 'model.pt')
    best_model_path = os.path.join(save_path, 'model_best.pt')

    summary_dir = 'summary'
    summary_dir = os.path.join(save_path, summary_dir)
    tf_callback = DumpTensorflowSummaries(
        save_path=summary_dir, step_multiplier=train_samples_per_epoch)

    dataset = get_dataset()
    featurizer = get_featurizer(featurizer_key)
    assert isinstance(featurizer, MeganTrainingSamplesFeaturizer)
    action_vocab = featurizer.get_actions_vocabulary(dataset.feat_dir)

    # copy featurizer dictionary files needed for using the model
    feat_dir = featurizer.dir(dataset.feat_dir)
    model_feat_dir = featurizer.dir(save_path)
    if not os.path.exists(model_feat_dir):
        os.makedirs(model_feat_dir)
    copyfile(get_actions_vocab_path(feat_dir),
             get_actions_vocab_path(model_feat_dir))
    copyfile(get_prop2oh_vocab_path(feat_dir),
             get_prop2oh_vocab_path(model_feat_dir))

    logger.info("Creating model...")
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = Megan(n_atom_actions=action_vocab['n_atom_actions'],
                  n_bond_actions=action_vocab['n_bond_actions'],
                  prop2oh=action_vocab['prop2oh']).to(device)
    summary(model)

    logger.info("Loading data...")
    data_dict = {}

    logger.info(f"Training for maximum of {max_n_epochs} epochs...")

    start_learning_rate = learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def set_lr(lr: float):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def run_batch(ind: np.ndarray, train: bool) -> dict:
        if train:
            optimizer.zero_grad()

        batch_ind = np.random.choice(ind, size=batch_size, replace=False)
        batch_metrics = {}
        batch = generate_batch(batch_ind,
                               data_dict['metadata'],
                               featurizer,
                               data_dict['data'],
                               action_vocab=action_vocab)

        batch_result = model(batch)
        actions = batch_result['output']

        target, n_steps = batch['target'], batch['n_steps']
        n_total_steps = torch.sum(n_steps)

        y_max_pred_prob, y_pred = torch.max(actions, dim=-1)
        y_val, y_true = torch.max(target, dim=-1)
        y_val_one = y_val == 1
        is_hard = batch['is_hard']

        weight = torch.ones_like(is_hard)
        avg_weight = torch.mean(weight.float(), axis=-1)

        weight = weight * y_val_one
        weight = weight.unsqueeze(-1).expand(*actions.shape)
        target_one = target == 1
        eps = 1e-09

        loss = -torch.log2(actions + ~target_one + eps) * target_one * weight
        loss = torch.sum(loss, dim=-1)
        path_losses = torch.sum(loss, dim=-1) / (avg_weight * 16)

        min_losses = []
        # for each reaction, use the minimum loss for each possible path as the loss to optimize
        path_i = 0
        for n_paths in batch['n_paths']:
            path_loss = torch.min(path_losses[path_i:path_i + n_paths])
            min_losses.append(path_loss.unsqueeze(-1))
            path_i += n_paths
        min_losses = torch.cat(min_losses)

        loss = torch.mean(min_losses)

        if torch.isinf(loss):
            raise ValueError(
                'Infinite loss (correct action has predicted probability=0.0)')

        if loss != loss:  # this is only true for NaN in pytorch
            raise ValueError('NaN loss')

        # skip accuracy metrics if there are no positive samples in batch
        correct = ((y_pred == y_true) & y_val_one).float()

        step_correct = torch.sum(correct) / n_total_steps
        batch_metrics['step_acc'] = step_correct.cpu().detach().numpy()

        total_hard = torch.sum(is_hard)
        if total_hard > 0:
            hard_correct = torch.sum(correct * is_hard) / total_hard
            batch_metrics['step_acc_hard'] = hard_correct.cpu().detach().numpy(
            )

        is_easy = (1.0 - is_hard) * y_val_one

        total_easy = torch.sum(is_easy)
        if total_easy > 0:
            easy_correct = torch.sum(correct * is_easy) / total_easy
            batch_metrics['step_acc_easy'] = easy_correct.cpu().detach().numpy(
            )

        all_correct = torch.sum(correct, dim=-1)
        all_correct = all_correct == n_steps
        acc = []
        path_i = 0
        for n_paths in batch['n_paths']:
            corr = any(all_correct[i] == 1
                       for i in range(path_i, path_i + n_paths))
            acc.append(corr)
            path_i += n_paths
        if len(acc) > 0:
            batch_metrics['acc'] = np.mean(acc)

        if train:
            loss.backward()
            optimizer.step()

        batch_metrics['loss'] = loss.cpu().detach().numpy()
        return batch_metrics

    def get_lr():
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def run_epoch(set_key: str,
                  i_ep: int,
                  all_ind: np.ndarray,
                  train: bool,
                  batches_per_epoch: int,
                  lr_step: float = 0.0):
        torch.cuda.empty_cache()
        if train:
            model.train()
        else:
            model.eval()

        metrics = {}
        counts = Counter()

        for batch_i in tqdm(range(batches_per_epoch),
                            desc=f'{save_path} {set_key} epoch {i_ep + 1}'):
            if lr_step > 0:
                set_lr(get_lr() + lr_step)
            try:
                batch_metrics = run_batch(all_ind, train)
                for k, v in batch_metrics.items():
                    if k not in metrics:
                        metrics[k] = 0
                    metrics[k] += v
                    counts[k] += 1
            except AssertionError as e:
                # batch skipped because of zero loss
                logger.debug(f"Exception while running batch: {str(e)}")
            except Exception as e:
                logger.warning(f"Exception while running batch: {str(e)}")
                raise e

        metrics = dict((k, v / counts[k]) for k, v in metrics.items())
        str_metrics = ', '.join("{:s}={:.4f}".format(k, v)
                                for k, v in metrics.items())
        logger.info(f'{set_key} epoch {i_ep + 1}: {str_metrics}')

        if train:
            save_weights(model_path,
                         model,
                         optimizer,
                         epoch=i_ep,
                         lr=get_lr(),
                         no_progress=no_progress)

            if save_each_epoch:
                model_epoch_path = os.path.join(
                    checkpoints_path,
                    f'model_{(i_ep + 1) * train_samples_per_epoch}.pt')
                save_weights(model_epoch_path,
                             model,
                             optimizer,
                             epoch=i_ep,
                             lr=get_lr())
        return metrics

    best_acc = 0
    no_progress = 0

    if os.path.exists(model_path):
        checkpoint = load_state_dict(model_path)
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1
        logger.info("Resuming training after {} epochs".format(start_epoch))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if 'lr' in checkpoint:
            learning_rate = checkpoint['lr']
            start_learning_rate = learning_rate
            logger.info(
                "Resuming training with LR={:f}".format(learning_rate))
            set_lr(learning_rate)
        if 'valid_acc' in checkpoint:
            best_acc = checkpoint['valid_acc']
            logger.info(f"Best acc so far: {best_acc}")

    megan_warmup_epochs = max(megan_warmup_epochs - start_epoch, 0)
    if megan_warmup_epochs > 0:
        learning_rate = 0.0
        set_lr(learning_rate)

    no_progress = 0
    no_progress_lr = 0

    logger.info('Loading data')
    loaded_data = featurizer.load(dataset.feat_dir)
    chunk_metadata = loaded_data['reaction_metadata']
    data_dict['data'] = loaded_data
    data_dict['metadata'] = chunk_metadata
    data_dict['mean_n_steps'] = np.mean(data_dict['metadata']['n_samples'])

    metadata = data_dict['metadata']
    if 'remapped' in metadata:
        train_ind = (metadata['is_train'] == 1) & (metadata['remapped'])
        valid_ind = (metadata['is_train'] == 0) & (metadata['remapped'])
    else:
        train_ind = metadata['is_train'] == 1
        valid_ind = metadata['is_train'] == 0

    if 'path_i' in metadata:
        train_ind = train_ind & (metadata['path_i'] == 0)
        valid_ind = valid_ind & (metadata['path_i'] == 0)

    train_ind = np.argwhere(train_ind).flatten()
    valid_ind = np.argwhere(valid_ind).flatten()

    logger.info(
        f"Training on chunk of {len(train_ind)} training samples and {len(valid_ind)} valid samples"
    )
    if train_samples_per_epoch == -1:
        train_samples_per_epoch = len(train_ind)
    if valid_samples_per_epoch == -1:
        valid_samples_per_epoch = len(valid_ind)
    train_batches_per_epoch = int(np.ceil(train_samples_per_epoch /
                                          batch_size))
    valid_batches_per_epoch = int(np.ceil(valid_samples_per_epoch /
                                          batch_size))

    logger.info(
        f'Starting training on epoch {start_epoch + 1} with Learning Rate={learning_rate} '
        f'({megan_warmup_epochs} warmup epochs)')

    for epoch_i in range(start_epoch, max_n_epochs):
        if epoch_i == megan_warmup_epochs:
            set_lr(start_learning_rate)
            logger.info(
                f'Learning rate set to {start_learning_rate} after {megan_warmup_epochs} warmup epochs'
            )

        if big_lr_epochs != -1 and epoch_i == big_lr_epochs:
            learning_rate *= gen_lr_factor
            no_progress = 0
            no_progress_lr = 0
            set_lr(learning_rate)
            logger.info(f'Changing Learning Rate to {learning_rate}')

        if megan_warmup_epochs > 0:
            warmup_lr_step = start_learning_rate / (train_batches_per_epoch *
                                                    megan_warmup_epochs)
        else:
            warmup_lr_step = 0

        learning_rate = get_lr()
        train_metrics = run_epoch(
            'train',
            epoch_i,
            train_ind,
            True,
            train_batches_per_epoch,
            lr_step=warmup_lr_step if epoch_i < megan_warmup_epochs else 0.0)
        with torch.no_grad():
            valid_metrics = run_epoch('valid', epoch_i, valid_ind, False,
                                      valid_batches_per_epoch)

        all_metrics = {}
        for key, val in train_metrics.items():
            all_metrics[f'train_{key}'] = val
        for key, val in valid_metrics.items():
            all_metrics[f'valid_{key}'] = val

        all_metrics['lr'] = learning_rate
        tf_callback.on_epoch_end(epoch_i + 1, all_metrics)

        valid_acc = valid_metrics['acc']
        if valid_acc > best_acc:
            logger.info(
                f'Saving best model from epoch {epoch_i + 1} to {best_model_path} (acc={valid_acc})'
            )
            save_weights(best_model_path,
                         model,
                         optimizer,
                         epoch=epoch_i,
                         lr=learning_rate,
                         valid_acc=valid_acc)

            best_acc = valid_acc
            no_progress = 0
            no_progress_lr = 0
        else:
            no_progress += 1
            no_progress_lr += 1

        if big_lr_epochs == -1 or epoch_i >= big_lr_epochs:
            if no_progress_lr > gen_lr_patience:
                learning_rate *= gen_lr_factor
                logger.info(f'Changing Learning Rate to {learning_rate}')
                set_lr(learning_rate)
                no_progress_lr = 0

            if no_progress > early_stopping:
                logger.info(f'Early stopping after {epoch_i + 1} epochs')
                break

    logger.info("Experiment finished!")
Example #10
 def save_weights_fnc(logs=None):
     logger.info("Saving model from beginning")
     save_weights(model.model, model.optimizer,
                  os.path.join(save_path, "init_weights.pt"))