Example #1
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    random.seed(31241)
    np.random.seed(41982)
    tf.set_random_seed(1327634)

    color = True # Must change this and the dataset Flags to the correct path to use color
    if FLAGS.is_debug:
        reader = Bouncing_Balls_Data_Reader(FLAGS.dataset, FLAGS.batch_size, color=color,
                                            train_size=160*5, validation_size=8*5,
                                            test_size=8*5, num_partitions=5)
    else:
        reader = Bouncing_Balls_Data_Reader(FLAGS.dataset, FLAGS.batch_size, color=color)

    data_fn = lambda epoch, batch_index: reader.read_data(batch_index, reader.TRAIN)
    frame_shape = reader.read_data(0, reader.TRAIN).shape[2:]
    print("Frame shape: ", frame_shape)
    num_batches = reader.num_batches(reader.TRAIN)
    print("Num batches: %d" % num_batches)
    input_sequence_range = range(5, 16)
    print("Input sequence range min: %d, max: %d" % (min(input_sequence_range), max(input_sequence_range)))

    save_sample_fn = utils.gen_save_sample_fn(FLAGS.sample_dir, image_prefix="train")

    with tf.Session() as sess:
        pgn = PGN(sess, FLAGS.dataset_name, FLAGS.epoch, num_batches, FLAGS.batch_size, input_sequence_range,
                  data_fn, frame_shape=frame_shape, save_sample_fn=save_sample_fn, checkpoint_dir=FLAGS.checkpoint_dir,
                  lambda_adv_loss=FLAGS.lambda_adv_loss)

        if FLAGS.is_train:
            pgn.train()
        else:
            print("Loading from: %s" %(FLAGS.checkpoint_dir,))
            if pgn.load(FLAGS.checkpoint_dir):
                print(" [*] Successfully loaded")
            else:
                print(" [!] Load failed")

        if FLAGS.is_test:
            result = test.test(pgn, reader)
            result_str = pp.pformat(result)
            with open(os.path.join(FLAGS.sample_dir, 'test_out.txt'), mode='w') as fid:
                fid.write(unicode(result_str))

        if FLAGS.is_visualize:
            for i in range(3):
                vid_seq = reader.read_data(i, data_set_type=reader.TEST, batch_size=1)[:, 0, :, :, :]
                utils.make_prediction_gif(pgn, os.path.join(FLAGS.sample_dir, 'vis_%d.gif' % i), video_sequence=vid_seq)
            utils.plot_convergence(pgn.get_MSE_history(), "MSE Convergence",
                                   path=os.path.join(FLAGS.sample_dir, "vis_MSE_convergence.png"))
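The main() above relies on module-level flag definitions and the TensorFlow 1.x entry point, which are not shown on this page. A minimal sketch follows; the defaults and help strings are assumptions inferred from how the flags are used in main(), and the project-local imports (PGN, Bouncing_Balls_Data_Reader, utils, test) are omitted:

# Sketch of the flag/entry-point boilerplate assumed by main() (defaults are guesses).
import os
import random
import pprint

import numpy as np
import tensorflow as tf

pp = pprint.PrettyPrinter()

flags = tf.app.flags
flags.DEFINE_string("dataset", "./data/bouncing_balls", "Path to the dataset directory")
flags.DEFINE_string("dataset_name", "bouncing_balls", "Name used for checkpoint files")
flags.DEFINE_string("checkpoint_dir", "./checkpoint", "Where to save/load checkpoints")
flags.DEFINE_string("sample_dir", "./samples", "Where to write generated samples")
flags.DEFINE_integer("epoch", 25, "Number of training epochs")
flags.DEFINE_integer("batch_size", 8, "Batch size")
flags.DEFINE_float("lambda_adv_loss", 0.05, "Weight of the adversarial loss term")
flags.DEFINE_boolean("is_debug", False, "Use the small debug dataset split")
flags.DEFINE_boolean("is_train", True, "Train the model")
flags.DEFINE_boolean("is_test", False, "Run the test pass after training/loading")
flags.DEFINE_boolean("is_visualize", False, "Write prediction GIFs and convergence plots")
FLAGS = flags.FLAGS

if __name__ == '__main__':
    tf.app.run()  # parses the flags and calls main(argv)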
Example #2
def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """

    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except attention.wc.
        print('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False
    # Build the training and validation datasets.
    print("loading data")
    train_data = SampleDataset(dataset.pairs, v)
    val_data = SampleDataset(val_dataset.pairs, v)

    print("initializing optimizer")

    # Define the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    val_losses = np.inf
    if (os.path.exists(config.losses_path)):
        with open(config.losses_path, 'rb') as f:
            val_losses = pickle.load(f)


    # torch.cuda.empty_cache()
    # SummaryWriter: log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # tqdm: a tool for drawing progress bars during training.
    # scheduled_sampler: decides whether teacher forcing is used in a given epoch.
    num_epochs = len(range(start_epoch, config.epochs))
    scheduled_sampler = ScheduledSampler(num_epochs)
    if config.scheduled_sampling:
        print('scheduled_sampling mode.')
    #  teacher_forcing = True

    with tqdm(total=config.epochs) as epoch_progress:
        for epoch in range(start_epoch, config.epochs):
            print(config_info(config))
            batch_losses = []  # Get loss of each batch.
            num_batches = len(train_dataloader)
            # set a teacher_forcing signal
            if config.scheduled_sampling:
                teacher_forcing = scheduled_sampler.teacher_forcing(
                    epoch - start_epoch)
            else:
                teacher_forcing = True
            print('teacher_forcing = {}'.format(teacher_forcing))
            with tqdm(total=num_batches) as batch_progress:
                for batch, data in enumerate(tqdm(train_dataloader)):
                    x, y, x_len, y_len, oov, len_oovs = data
                    assert not np.any(np.isnan(x.numpy()))
                    if config.is_cuda:  # Training with GPUs.
                        x = x.to(DEVICE)
                        y = y.to(DEVICE)
                        x_len = x_len.to(DEVICE)
                        len_oovs = len_oovs.to(DEVICE)

                    model.train()  # Sets the module in training mode.
                    optimizer.zero_grad()  # Clear gradients.
                    # Calculate loss.  Call model forward propagation
                    loss = model(x,
                                 x_len,
                                 y,
                                 len_oovs,
                                 batch=batch,
                                 num_batches=num_batches,
                                 teacher_forcing=teacher_forcing)
                    batch_losses.append(loss.item())
                    loss.backward()  # Backpropagation.

                    # Do gradient clipping to prevent gradient explosion.
                    clip_grad_norm_(model.encoder.parameters(),
                                    config.max_grad_norm)
                    clip_grad_norm_(model.decoder.parameters(),
                                    config.max_grad_norm)
                    clip_grad_norm_(model.attention.parameters(),
                                    config.max_grad_norm)
                    optimizer.step()  # Update weights.

                    # Output and record epoch loss every 32 batches.
                    if (batch % 32) == 0:
                        batch_progress.set_description(f'Epoch {epoch}')
                        batch_progress.set_postfix(Batch=batch,
                                                   Loss=loss.item())
                        batch_progress.update()
                        # Write loss for tensorboard.
                        writer.add_scalar(f'Average loss for epoch {epoch}',
                                          np.mean(batch_losses),
                                          global_step=batch)
            # Calculate average loss over all batches in an epoch.
            epoch_loss = np.mean(batch_losses)

            epoch_progress.set_description(f'Epoch {epoch}')
            epoch_progress.set_postfix(Loss=epoch_loss)
            epoch_progress.update()

            avg_val_loss = evaluate(model, val_data, epoch)

            print('training loss:{}'.format(epoch_loss),
                  'validation loss:{}'.format(avg_val_loss))

            # Update minimum evaluating loss.
            if (avg_val_loss < val_losses):
                torch.save(model.encoder, config.encoder_save_name)
                torch.save(model.decoder, config.decoder_save_name)
                torch.save(model.attention, config.attention_save_name)
                torch.save(model.reduce_state, config.reduce_state_save_name)
                val_losses = avg_val_loss
            with open(config.losses_path, 'wb') as f:
                pickle.dump(val_losses, f)

    writer.close()
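ScheduledSampler is imported from the surrounding project and is not shown on this page. For reference, a minimal sketch of the interface implied by the call scheduled_sampler.teacher_forcing(epoch - start_epoch) follows; the linear decay is an assumption for illustration, not necessarily what the project uses:

# Minimal sketch of a scheduled sampler (assumption: linear decay of teacher forcing).
import random

class ScheduledSampler:
    def __init__(self, phases):
        # One phase per training epoch; the probability of sampling from the
        # model (instead of teacher forcing) grows linearly from 0 to 1.
        self.phases = phases
        denom = max(self.phases - 1, 1)
        self.scheduled_probs = [i / denom for i in range(self.phases)]

    def teacher_forcing(self, phase):
        # Return True (use teacher forcing) with probability 1 - scheduled_probs[phase].
        return random.random() >= self.scheduled_probs[phase]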
Example #3

def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    torch.autograd.set_detect_anomaly(True)  # Surface NaN/Inf errors in the backward pass.
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except attention.wc.
        logging.info('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False
    # Use the datasets passed in directly.
    logging.info("loading data")
    train_data = dataset
    val_data = val_dataset

    logging.info("initializing optimizer")

    # Define the optimizer.
    #     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = optim.Adagrad(
        model.parameters(),
        lr=config.learning_rate,
        initial_accumulator_value=config.initial_accumulator_value)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.2)  # learning-rate schedule
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    val_loss = np.inf
    if (os.path.exists(config.losses_path)):
        with open(config.losses_path, 'r') as f:
            val_loss = float(f.readlines()[-1].split("=")[-1])
            logging.info("the last best val loss is: " + str(val_loss))


    # torch.cuda.empty_cache()
    # SummaryWriter: log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # tqdm: A tool for drawing progress bars during training.
    early_stopping_count = 0

    logging.info("start training model {}, ".format(config.model_name) + \
        "epoch : {}, ".format(config.epochs) +
        "batch_size : {}, ".format(config.batch_size) +
        "num batches: {}, ".format(len(train_dataloader)))

    for epoch in range(start_epoch, config.epochs):
        batch_losses = []  # Get loss of each batch.
        num_batches = len(train_dataloader)
        # with tqdm(total=num_batches // 100) as batch_progress:
        for batch, data in enumerate(train_dataloader):
            x, y, x_len, y_len, oov, len_oovs, img_vec = data
            assert not np.any(np.isnan(x.numpy()))
            if config.is_cuda:  # Training with GPUs.
                x = x.to(DEVICE)
                y = y.to(DEVICE)
                x_len = x_len.to(DEVICE)
                len_oovs = len_oovs.to(DEVICE)
                img_vec = img_vec.to(DEVICE)
            if batch == 0:
                logging.info("x: %s, shape: %s" % (x, x.shape))
                logging.info("y: %s, shape: %s" % (y, y.shape))
                logging.info("oov: %s" % oov)
                logging.info("img_vec: %s, shape: %s" %
                             (img_vec, img_vec.shape))

            model.train()  # Sets the module in training mode.
            optimizer.zero_grad()  # Clear gradients.

            loss = model(x,
                         y,
                         len_oovs,
                         img_vec,
                         batch=batch,
                         num_batches=num_batches)
            batch_losses.append(loss.item())
            loss.backward()  # Backpropagation.

            # Do gradient clipping to prevent gradient explosion.
            clip_grad_norm_(model.encoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.decoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.attention.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.reduce_state.parameters(),
                            config.max_grad_norm)
            optimizer.step()  # Update weights.
            # scheduler.step()

            # Output and record epoch loss every 100 batches.
            if (batch % 100) == 0:
                # batch_progress.set_description(f'Epoch {epoch}')
                # batch_progress.set_postfix(Batch=batch, Loss=loss.item())
                # batch_progress.update()
                # Write loss for tensorboard.
                writer.add_scalar(f'Average_loss_for_epoch_{epoch}',
                                  np.mean(batch_losses),
                                  global_step=batch)
                logging.info('epoch: {}, batch:{}, training loss:{}'.format(
                    epoch, batch, np.mean(batch_losses)))

        # Calculate average loss over all batches in an epoch.
        epoch_loss = np.mean(batch_losses)

        # epoch_progress.set_description(f'Epoch {epoch}')
        # epoch_progress.set_postfix(Loss=epoch_loss)
        # epoch_progress.update()

        avg_val_loss = evaluate(model, val_data, epoch)

        logging.info('epoch: {} '.format(epoch) +
                     'training loss:{} '.format(epoch_loss) +
                     'validation loss:{} '.format(avg_val_loss))

        # Update minimum evaluating loss.
        if not os.path.exists(os.path.dirname(config.encoder_save_name)):
            os.mkdir(os.path.dirname(config.encoder_save_name))
        if (avg_val_loss < val_loss):
            logging.info("saving model to ../saved_model/ %s" %
                         config.model_name)
            torch.save(model.encoder, config.encoder_save_name)
            torch.save(model.decoder, config.decoder_save_name)
            torch.save(model.attention, config.attention_save_name)
            torch.save(model.reduce_state, config.reduce_state_save_name)
            val_loss = avg_val_loss
            with open(config.losses_path, 'a') as f:
                f.write(f"best val loss={val_loss}\n")
        else:
            early_stopping_count += 1
        if early_stopping_count >= config.patience:
            logging.info(
                f'Validation loss did not decrease for {config.patience} epochs, stop training.'
            )
            break

    writer.close()
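As a usage note, both train() variants expect already-built pair datasets and a vocabulary. A hypothetical driver is sketched below; the module names, config fields, and the build_vocab() helper are assumptions for illustration only:

# Hypothetical driver; paths, config fields, and build_vocab() are assumptions.
from dataset import PairDataset
import config

if __name__ == '__main__':
    train_pairs = PairDataset(config.train_data_path)
    val_pairs = PairDataset(config.val_data_path)
    v = train_pairs.build_vocab(embed_file=config.embed_file)
    train(train_pairs, val_pairs, v, start_epoch=0)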