Example #1
def main(args):
    if args.randomize_checkpoint_path == 1:
        # append a random suffix so concurrent runs do not clobber each other's checkpoints
        name, ext = os.path.splitext(args.checkpoint_path)
        num = random.randint(1, 1000000)
        args.checkpoint_path = '%s_%06d%s' % (name, num, ext)

    vocab = utils.load_vocab(args.vocab_json)

    if args.use_local_copies == 1:
        shutil.copy(args.train_question_h5, '/tmp/train_questions.h5')
        shutil.copy(args.train_features_h5, '/tmp/train_features.h5')
        shutil.copy(args.val_question_h5, '/tmp/val_questions.h5')
        shutil.copy(args.val_features_h5, '/tmp/val_features.h5')
        args.train_question_h5 = '/tmp/train_questions.h5'
        args.train_features_h5 = '/tmp/train_features.h5'
        args.val_question_h5 = '/tmp/val_questions.h5'
        args.val_features_h5 = '/tmp/val_features.h5'

    question_families = None
    if args.family_split_file is not None:
        with open(args.family_split_file, 'r') as f:
            question_families = json.load(f)

    train_loader_kwargs = {
        'question_h5': args.train_question_h5,
        'feature_h5': args.train_features_h5,
        'vocab': vocab,
        'batch_size': args.batch_size,
        'shuffle': args.shuffle_train_data == 1,
        'question_families': question_families,
        'max_samples': args.num_train_samples,
        'num_workers': args.loader_num_workers,
    }
    val_loader_kwargs = {
        'question_h5': args.val_question_h5,
        'feature_h5': args.val_features_h5,
        'vocab': vocab,
        'batch_size': args.batch_size,
        'question_families': question_families,
        'max_samples': args.num_val_samples,
        'num_workers': args.loader_num_workers,
    }

    with ClevrDataLoader(**train_loader_kwargs) as train_loader, \
         ClevrDataLoader(**val_loader_kwargs) as val_loader:
        train_loop(args, train_loader, val_loader)

    if args.use_local_copies == 1 and args.cleanup_local_copies == 1:
        os.remove('/tmp/train_questions.h5')
        os.remove('/tmp/train_features.h5')
        os.remove('/tmp/val_questions.h5')
        os.remove('/tmp/val_features.h5')
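
Example #1 reads a number of 0/1 integer flags and path arguments. A minimal argparse sketch covering just the options this function touches (flag names are taken from the code above; defaults are illustrative assumptions, not the repository's values):

import argparse

parser = argparse.ArgumentParser()
# this codebase uses 0/1 integer flags rather than store_true booleans
parser.add_argument('--randomize_checkpoint_path', type=int, default=0)
parser.add_argument('--use_local_copies', type=int, default=0)
parser.add_argument('--cleanup_local_copies', type=int, default=1)
parser.add_argument('--shuffle_train_data', type=int, default=1)
# paths (placeholder defaults)
parser.add_argument('--checkpoint_path', default='data/checkpoint.pt')
parser.add_argument('--vocab_json', default='data/vocab.json')
parser.add_argument('--train_question_h5', default='data/train_questions.h5')
parser.add_argument('--train_features_h5', default='data/train_features.h5')
parser.add_argument('--val_question_h5', default='data/val_questions.h5')
parser.add_argument('--val_features_h5', default='data/val_features.h5')
parser.add_argument('--family_split_file', default=None)
# loader settings
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_train_samples', type=int, default=None)
parser.add_argument('--num_val_samples', type=int, default=10000)
parser.add_argument('--loader_num_workers', type=int, default=1)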
Example #2
def main(args):
  print()
  model = None
  if args.baseline_model is not None:
    print('Loading baseline model from ', args.baseline_model)
    model, _ = utils.load_baseline(args.baseline_model)
    if args.vocab_json is not None:
      new_vocab = utils.load_vocab(args.vocab_json)
      model.rnn.expand_vocab(new_vocab['question_token_to_idx'])
  elif args.program_generator is not None and args.execution_engine is not None:
    print('Loading program generator from ', args.program_generator)
    program_generator, _ = utils.load_program_generator(args.program_generator)
    print('Loading execution engine from ', args.execution_engine)
    execution_engine, _ = utils.load_execution_engine(args.execution_engine, verbose=False)
    if args.vocab_json is not None:
      new_vocab = utils.load_vocab(args.vocab_json)
      program_generator.expand_encoder_vocab(new_vocab['question_token_to_idx'])
    model = (program_generator, execution_engine)
  else:
    print('Must give either --baseline_model or --program_generator and --execution_engine')
    return

  if args.question is not None and args.image is not None:
    run_single_example(args, model)
  else:
    vocab = load_vocab(args)
    loader_kwargs = {
      'question_h5': args.input_question_h5,
      'feature_h5': args.input_features_h5,
      'vocab': vocab,
      'batch_size': args.batch_size,
    }
    if args.num_samples is not None and args.num_samples > 0:
      loader_kwargs['max_samples'] = args.num_samples
    if args.family_split_file is not None:
      with open(args.family_split_file, 'r') as f:
        loader_kwargs['question_families'] = json.load(f)
    with ClevrDataLoader(**loader_kwargs) as loader:
      run_batch(args, model, loader)
Example #3
def check_accuracy(args, program_generator, execution_engine, baseline_model,
                   loader):
    set_mode('eval', [program_generator, execution_engine, baseline_model])
    num_correct, num_samples = 0, 0
    for batch in loader:
        if num_samples % 30 == 0:
            print('process', num_samples, end='\r')
        refexps, _, feats, answers, programs, _, _ = batch

        refexps_var = Variable(refexps.cuda(), volatile=True)
        feats_var = Variable(feats.cuda(), volatile=True)
        answers_var = Variable(answers.cuda(), volatile=True)
        if programs[0] is not None:
            programs_var = Variable(programs.cuda(), volatile=True)

        scores = None  # Use this for everything but PG
        if args.model_type == 'PG':
            vocab = utils.load_vocab(args.vocab_json)
            for i in range(refexps.size(0)):
                program_pred = program_generator.sample(
                    Variable(refexps[i:i + 1].cuda(), volatile=True))
                program_pred_str = iep.preprocess.decode(
                    program_pred, vocab['program_idx_to_token'])
                program_str = iep.preprocess.decode(
                    programs[i], vocab['program_idx_to_token'])
                if program_pred_str == program_str:
                    num_correct += 1
                num_samples += 1
        elif args.model_type == 'EE':
            scores = execution_engine(feats_var, programs_var)
            scores = None  # intentionally discarded; EE outputs are not scored against answers here
        elif args.model_type == 'PG+EE':
            programs_pred = program_generator.reinforce_sample(refexps_var,
                                                               argmax=True)
            scores = execution_engine(feats_var, programs_pred)
        elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
            scores = baseline_model(refexps_var, feats_var)

        if scores is not None:
            _, preds = scores.data.cpu().max(1)
            num_correct += (preds == answers).sum()
            num_samples += preds.size(0)

        if num_samples >= args.num_val_samples:
            break

    set_mode('train', [program_generator, execution_engine, baseline_model])
    acc = float(num_correct) / (num_samples + 0.000001)
    return acc
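
Example #3 uses the pre-0.4 PyTorch inference idiom, Variable(..., volatile=True); Example #4 below is essentially the same accuracy check rewritten for PyTorch 0.4+, where the equivalent is a torch.no_grad() context:

# pre-0.4 idiom (Example #3): volatile propagates through the graph and disables autograd
feats_var = Variable(feats.cuda(), volatile=True)

# 0.4+ idiom (Example #4 and later): everything inside the block is gradient-free
with torch.no_grad():
    scores = model(feats.cuda())  # 'model' stands in for whichever network is scored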
Example #4
def check_accuracy(args, program_generator, execution_engine, baseline_model, loader):
  set_mode('eval', [program_generator, execution_engine, baseline_model])
  num_correct, num_samples = 0, 0
  for batch in tqdm(loader):
    questions, images, feats, answers, programs, _, ocr_tokens = batch
    with torch.no_grad():
      questions_var = Variable(questions.cuda())
      if feats[0] is not None:
        feats_var = Variable(feats.cuda())
      answers_var = Variable(answers.cuda())
      if programs[0] is not None:
        programs_var = Variable(programs.cuda())
    scores = None # Use this for everything but PG
    if args.model_type == 'PG':
      vocab = utils.load_vocab(args.vocab_json)
      for i in range(questions.size(0)):
        with torch.no_grad():
          program_pred = program_generator.sample(Variable(questions[i:i+1].cuda()))
        program_pred_str = iep.preprocess.decode(program_pred, vocab['program_idx_to_token'])
        program_str = iep.preprocess.decode(programs[i].tolist(), vocab['program_idx_to_token'])
        if program_pred_str == program_str:
          num_correct += 1
        num_samples += 1
    elif args.model_type == 'EE':
      text_embs = process_tokens(ocr_tokens).cuda()
      scores = execution_engine(feats_var, programs_var, text_embs)
    elif args.model_type == 'PG+EE':
      programs_pred = program_generator.reinforce_sample(
                          questions_var, argmax=True)
      scores = execution_engine(feats_var, programs_pred)
    elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
      scores = baseline_model(questions_var, feats_var)
    elif args.model_type == 'PG+EE+GQNT':
      programs_pred = program_generator.reinforce_sample(questions_var)
      text_embs = process_tokens(ocr_tokens)
      scores = execution_engine(feats_var, programs_pred, text_embs)

    if scores is not None:
      _, preds = scores.data.cpu().max(1)
      num_correct += (preds == answers).sum()
      num_samples += preds.size(0)

    if num_samples >= args.num_val_samples:
      break

  set_mode('train', [program_generator, execution_engine, baseline_model])
  acc = float(num_correct) / num_samples
  return acc
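
One caveat in Example #4: its torch.no_grad() blocks wrap only the tensor transfers, so the forward passes further down still build autograd graphs. For pure evaluation, the scoring step itself would normally sit inside the context:

with torch.no_grad():
    # the forward pass is now gradient-free, not just the .cuda() transfers
    scores = baseline_model(questions_var, feats_var)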
Example #5
def get_execution_engine(args):
    vocab = utils.load_vocab(args.vocab_json)
    kwargs = {
        'vocab': vocab,
        'feature_dim': parse_int_list(args.feature_dim),
        'stem_batchnorm': args.module_stem_batchnorm == 1,
        'stem_num_layers': args.module_stem_num_layers,
        'module_dim': args.module_dim,
        'module_residual': args.module_residual == 1,
        'module_batchnorm': args.module_batchnorm == 1,
        'classifier_proj_dim': args.classifier_proj_dim,
        'classifier_downsample': args.classifier_downsample,
        'classifier_fc_layers': parse_int_list(args.classifier_fc_dims),
        'classifier_batchnorm': args.classifier_batchnorm == 1,
        'classifier_dropout': args.classifier_dropout,
    }
    ee = ModuleNet(**kwargs)
    return ee, kwargs
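
parse_int_list is used here and in Example #7 but never shown in this listing. A plausible implementation, assuming feature_dim and classifier_fc_dims arrive as comma-separated strings such as '1024,14,14' (an assumption, not code from the repository):

def parse_int_list(s):
    # hypothetical helper: '1024,14,14' -> (1024, 14, 14)
    return tuple(int(n) for n in s.split(','))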
Example #6
def get_program_generator(args):
  vocab = utils.load_vocab(args.vocab_json)
  if args.program_generator_start_from is not None:
    pg, kwargs = utils.load_program_generator(args.program_generator_start_from)
    cur_vocab_size = pg.encoder_embed.weight.size(0)
    if cur_vocab_size != len(vocab['refexp_token_to_idx']):
      print('Expanding vocabulary of program generator')
      pg.expand_encoder_vocab(vocab['refexp_token_to_idx'])
      kwargs['encoder_vocab_size'] = len(vocab['refexp_token_to_idx'])
  else:
    kwargs = {
      'encoder_vocab_size': len(vocab['refexp_token_to_idx']),
      'decoder_vocab_size': len(vocab['program_token_to_idx']),
      'wordvec_dim': args.rnn_wordvec_dim,
      'hidden_dim': args.rnn_hidden_dim,
      'rnn_num_layers': args.rnn_num_layers,
      'rnn_dropout': args.rnn_dropout,
    }
    pg = Seq2Seq(**kwargs)
  pg.cuda()
  pg.train()
  return pg, kwargs
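
The kwargs returned here are what the training loops (e.g. Example #8) store under 'program_generator_kwargs' in their checkpoints, so a saved model can be rebuilt without re-deriving its configuration. A minimal reload sketch, assuming the checkpoint layout from Example #8 and that get_state saved an ordinary state dict:

ckpt = torch.load('checkpoint.pt', map_location='cpu')  # path is a placeholder
pg = Seq2Seq(**ckpt['program_generator_kwargs'])
pg.load_state_dict(ckpt['program_generator_state'])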
Example #7
def get_execution_engine(args):
  vocab = utils.load_vocab(args.vocab_json)
  if args.execution_engine_start_from is not None:
    ee, kwargs = utils.load_execution_engine(args.execution_engine_start_from)
    # TODO: Adjust vocab?
  else:
    kwargs = {
      'vocab': vocab,
      'feature_dim': parse_int_list(args.feature_dim),
      'stem_batchnorm': args.module_stem_batchnorm == 1,
      'stem_num_layers': args.module_stem_num_layers,
      'module_dim': args.module_dim,
      'module_residual': args.module_residual == 1,
      'module_batchnorm': args.module_batchnorm == 1,
      'classifier_proj_dim': args.classifier_proj_dim,
      'classifier_downsample': args.classifier_downsample,
      'classifier_fc_layers': parse_int_list(args.classifier_fc_dims),
      'classifier_batchnorm': args.classifier_batchnorm == 1,
      'classifier_dropout': args.classifier_dropout,
    }
    ee = ModuleNet(**kwargs)
  ee.cuda()
  ee.train()
  return ee, kwargs
Example #8
def train_loop(args, train_loader, val_loader):
    vocab = utils.load_vocab(args.vocab_json)
    program_generator, pg_kwargs, pg_optimizer = None, None, None
    execution_engine, ee_kwargs, ee_optimizer = None, None, None
    baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
    baseline_type = None

    best_pg_state, best_ee_state, best_baseline_state = None, None, None

    # Set up model
    if args.model_type == 'PG' or args.model_type == 'PG+EE':
        program_generator, pg_kwargs = get_program_generator(args)
        pg_optimizer = torch.optim.Adam(program_generator.parameters(),
                                        lr=args.learning_rate)
        print('Here is the program generator:')
        print(program_generator)
    if args.model_type == 'EE' or args.model_type == 'PG+EE':
        execution_engine, ee_kwargs = get_execution_engine(args)
        ee_optimizer = torch.optim.Adam(execution_engine.parameters(),
                                        lr=args.learning_rate)
        print('Here is the execution engine:')
        print(execution_engine)
    if args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
        baseline_model, baseline_kwargs = get_baseline_model(args)
        params = baseline_model.parameters()
        if args.baseline_train_only_rnn == 1:
            params = baseline_model.rnn.parameters()
        baseline_optimizer = torch.optim.Adam(params, lr=args.learning_rate)
        print('Here is the baseline model')
        print(baseline_model)
        baseline_type = args.model_type
    loss_fn = torch.nn.CrossEntropyLoss().cuda()

    stats = {
        'train_losses': [],
        'train_rewards': [],
        'train_losses_ts': [],
        'train_accs': [],
        'val_accs': [],
        'val_accs_ts': [],
        'best_val_acc': -1,
        'model_t': 0,
    }
    t, epoch, reward_moving_average = 0, 0, 0

    set_mode('train', [program_generator, execution_engine, baseline_model])

    print('train_loader has %d samples' % len(train_loader.dataset))
    print('val_loader has %d samples' % len(val_loader.dataset))

    while t < args.num_iterations:
        epoch += 1
        print('Starting epoch %d' % epoch)
        for batch in train_loader:
            t += 1
            questions, _, feats, answers, programs, _ = batch
            questions_var = Variable(questions.cuda())
            feats_var = Variable(feats.cuda())
            answers_var = Variable(answers.cuda())
            if programs[0] is not None:
                programs_var = Variable(programs.cuda())

            reward = None
            if args.model_type == 'PG':
                # Train program generator with ground-truth programs
                pg_optimizer.zero_grad()
                loss = program_generator(questions_var, programs_var)
                loss.backward()
                pg_optimizer.step()
            elif args.model_type == 'EE':
                # Train execution engine with ground-truth programs
                ee_optimizer.zero_grad()
                scores = execution_engine(feats_var, programs_var)
                loss = loss_fn(scores, answers_var)
                loss.backward()
                ee_optimizer.step()
            elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
                baseline_optimizer.zero_grad()
                baseline_model.zero_grad()
                scores = baseline_model(questions_var, feats_var)
                loss = loss_fn(scores, answers_var)
                loss.backward()
                baseline_optimizer.step()
            elif args.model_type == 'PG+EE':
                programs_pred = program_generator.reinforce_sample(
                    questions_var)
                scores = execution_engine(feats_var, programs_pred)

                loss = loss_fn(scores, answers_var)
                _, preds = scores.data.cpu().max(1)
                raw_reward = (preds == answers).float()
                reward_moving_average *= args.reward_decay
                reward_moving_average += (
                    1.0 - args.reward_decay) * raw_reward.mean()
                centered_reward = raw_reward - reward_moving_average

                if args.train_execution_engine == 1:
                    ee_optimizer.zero_grad()
                    loss.backward()
                    ee_optimizer.step()

                if args.train_program_generator == 1:
                    pg_optimizer.zero_grad()
                    program_generator.reinforce_backward(
                        centered_reward.cuda())
                    pg_optimizer.step()

            if t % args.record_loss_every == 0:
                print(t, loss.data[0])
                stats['train_losses'].append(loss.data[0])
                stats['train_losses_ts'].append(t)
                if reward is not None:
                    stats['train_rewards'].append(reward)

            if t % args.checkpoint_every == 0:
                print('Checking training accuracy ... ')
                train_acc = check_accuracy(args, program_generator,
                                           execution_engine, baseline_model,
                                           train_loader)
                print('train accuracy is', train_acc)
                print('Checking validation accuracy ...')
                val_acc = check_accuracy(args, program_generator,
                                         execution_engine, baseline_model,
                                         val_loader)
                print('val accuracy is ', val_acc)
                stats['train_accs'].append(train_acc)
                stats['val_accs'].append(val_acc)
                stats['val_accs_ts'].append(t)

                if val_acc > stats['best_val_acc']:
                    stats['best_val_acc'] = val_acc
                    stats['model_t'] = t
                    best_pg_state = get_state(program_generator)
                    best_ee_state = get_state(execution_engine)
                    best_baseline_state = get_state(baseline_model)

                checkpoint = {
                    'args': args.__dict__,
                    'program_generator_kwargs': pg_kwargs,
                    'program_generator_state': best_pg_state,
                    'execution_engine_kwargs': ee_kwargs,
                    'execution_engine_state': best_ee_state,
                    'baseline_kwargs': baseline_kwargs,
                    'baseline_state': best_baseline_state,
                    'baseline_type': baseline_type,
                    'vocab': vocab
                }
                for k, v in stats.items():
                    checkpoint[k] = v
                print('Saving checkpoint to %s' % args.checkpoint_path)
                torch.save(checkpoint, args.checkpoint_path)
                del checkpoint['program_generator_state']
                del checkpoint['execution_engine_state']
                del checkpoint['baseline_state']
                with open(args.checkpoint_path + '.json', 'w') as f:
                    json.dump(checkpoint, f)

            if t == args.num_iterations:
                break
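
The PG+EE branch above is REINFORCE with an exponential-moving-average baseline: the per-sample reward is answer correctness, and subtracting the running mean reduces gradient variance. The same update, isolated as a standalone sketch (names are illustrative):

import torch

def centered_reward(preds, answers, baseline, decay=0.99):
    raw = (preds == answers).float()  # 1.0 for a correct answer, 0.0 otherwise
    # exponential moving average of accuracy serves as the baseline
    baseline = decay * baseline + (1.0 - decay) * raw.mean().item()
    return raw - baseline, baseline

# usage: reward, running_baseline = centered_reward(preds, answers, running_baseline)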
Example #9
def train_loop(args, train_loader, val_loader):
    vocab = utils.load_vocab(args.vocab_json)
    program_generator, pg_kwargs, pg_optimizer = None, None, None
    execution_engine, ee_kwargs, ee_optimizer = None, None, None
    baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
    baseline_type = None

    pg_best_state, ee_best_state, baseline_best_state = None, None, None

    # note: this Keras ModelCheckpoint is dead code; the name is rebound to a
    # tf.train.Checkpoint further down, before it is ever used
    checkpoint = ModelCheckpoint(args.checkpoint_path,
                                 monitor='loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min',
                                 load_weights_on_restart=True)
    # Set up model
    if args.model_type == 'PG' or args.model_type == 'PG+EE':
        program_generator, pg_kwargs = get_program_generator(args)
        pg_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the program generator:')
        # program_generator.build(input_shape=[46,])
        # program_generator.compile(optimizer='adam', loss='mse')
        # print(program_generator.summary())
    if args.model_type == 'EE' or args.model_type == 'PG+EE':
        execution_engine, ee_kwargs = get_execution_engine(args)
        ee_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the execution engine:')
        print(execution_engine)

    stats = {
        'train_losses': [],
        'train_rewards': [],
        'train_losses_ts': [],
        'train_accs': [],
        'val_accs': [],
        'val_accs_ts': [],
        'best_val_acc': -1,
        'model_t': 0,
    }
    t, epoch, reward_moving_average = 0, 0, 0
    batch_size = 64  # note: hardcoded; args.batch_size is not used in this loop
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=pg_optimizer,
                                     program_generator=program_generator)
    # set_mode('train', [program_generator, execution_engine, baseline_model])

    print('train_loader has %d samples' % len(train_loader))
    # train_loader = train_loader[:256]
    print('val_loader has %d samples' % len(val_loader))
    # data_sampler = iter(range(len(train_loader)))
    data_load = batch_creater(train_loader, batch_size, False)
    print("Data load length :", len(data_load))
    # print(data_load[0][0])

    while t < args.num_iterations:
        total_loss = 0
        epoch += 1
        print('Starting epoch %d' % epoch)
        print("value of t :", t)
        # train_loader_data = get_data(train_loader)
        # print("train data loader length :", len(train_loader_data))
        # print(train_loader[0].shape)
        # print(train_loader[0])
        for run_num, batch in enumerate(data_load):
            batch_loss = 0
            with tf.GradientTape() as tape:
                t += 1
                questions, _, feats, answers, programs, _ = to_tensor(
                    batch[0]), batch[1], to_tensor(batch[2]), to_tensor(
                        batch[3]), to_tensor(batch[4]), batch[5]

                #print("Questions : ", questions.shape)
                #print("Features :", feats.shape)
                #print(" Answers : ", answers.shape)
                #print(" prgrams : ", programs.shape)
                print("----------------")

                questions_var = tf.Variable(questions)
                feats_var = tf.Variable(feats)
                answers_var = tf.Variable(answers)
                if programs[0] is not None:
                    programs_var = tf.Variable(programs)

                reward = None
                if args.model_type == 'PG':
                    #checkpoint_dir = './training_checkpoints'
                    #checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
                    # checkpoint = tf.train.Checkpoint(optimizer=pg_optimizer,
                    #                                 program_generator=program_generator)
                    # Train program generator with ground-truth programs
                    batch_loss = program_generator(questions_var, programs_var)
            total_loss += batch_loss
            variables = program_generator.variables
            gradients = tape.gradient(batch_loss, variables)
            pg_optimizer.apply_gradients(zip(gradients, variables))

            # float() handles both tf tensors and the untouched int 0
            print('Epoch {} Batch No. {} Loss {:.4f}'.format(
                epoch, run_num, float(batch_loss)))
        if epoch % 2 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
        if t == args.num_iterations:
            break
            # program_generator.compile(optimizer=pg_optimizer, loss=loss)
            # ques = np.asarray(questions_var.read_value())
            # prog = np.asarray(programs_var.read_value())
            # history = program_generator.fit(
            #     x=ques,
            #     y=prog,
            #     batch_size=args.batch_size,
            #     epochs=10,
            #     verbose=0,
            #     callbacks=[LossAndErrorPrintingCallback(), checkpoint])

            # elif args.model_type == 'EE':
            #     # Train execution engine with ground-truth programs
            #     scores = execution_engine(feats_var, programs_var)
            #     loss = tf.nn.softmax_cross_entropy_with_logits(
            #         scores, answers_var)
            #     execution_engine.compile(optimizer=ee_optimizer, loss=loss)
            #     history = execution_engine.fit(
            #         questions_var,
            #         to_categorical(answers_var),
            #         batch_size=args.batch_size,
            #         epochs=10,
            #         verbose=0,
            #         callbacks=[LossAndErrorPrintingCallback(), checkpoint])

            # elif args.model_type == 'PG+EE':
            #     programs_pred = program_generator.reinforce_sample(questions_var)
            #     scores = execution_engine(feats_var, programs_pred)
            #
            #     loss = tf.nn.softmax_cross_entropy_with_logits(scores, answers_var)
            #     _, preds = scores.data.max(1)
            #     # raw_reward = (preds == answers).float()
            #     raw_reward = tf.cast((preds == answers), dtype=tf.float32)
            #     reward_moving_average *= args.reward_decay
            #     reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean()
            #     centered_reward = raw_reward - reward_moving_average
            #
            #     if args.train_execution_engine == 1:
            #         ee_optimizer.zero_grad()
            #         loss.backward()
            #         ee_optimizer.step()
            #
            #     if args.train_program_generator == 1:
            #         pg_optimizer.zero_grad()
            #         program_generator.reinforce_backward(centered_reward.cuda())
            #         pg_optimizer.step()

            # if t % args.record_loss_every == 0:
            #     print(t, loss.data[0])
            #     stats['train_losses'].append(loss.data[0])
            #     stats['train_losses_ts'].append(t)
            #     if reward is not None:
            #         stats['train_rewards'].append(reward)
            #
            # if t % args.checkpoint_every == 0:
            #     print('Checking training accuracy ... ')
            #     train_acc = check_accuracy(args, program_generator, execution_engine,
            #                                baseline_model, train_loader)
            #     print('train accuracy is', train_acc)
            #     print('Checking validation accuracy ...')
            #     val_acc = check_accuracy(args, program_generator, execution_engine,
            #                              baseline_model, val_loader)
            #     print('val accuracy is ', val_acc)
            #     stats['train_accs'].append(train_acc)
            #     stats['val_accs'].append(val_acc)
            #     stats['val_accs_ts'].append(t)
            #
            #     if val_acc > stats['best_val_acc']:
            #         stats['best_val_acc'] = val_acc
            #         stats['model_t'] = t
            #         best_pg_state = get_state(program_generator)
            #         best_ee_state = get_state(execution_engine)
            #         best_baseline_state = get_state(baseline_model)
            #
            #     checkpoint = {
            #         'args': args.__dict__,
            #         'program_generator_kwargs': pg_kwargs,
            #         'program_generator_state': best_pg_state,
            #         'execution_engine_kwargs': ee_kwargs,
            #         'execution_engine_state': best_ee_state,
            #         'baseline_kwargs': baseline_kwargs,
            #         'baseline_state': best_baseline_state,
            #         'baseline_type': baseline_type,
            #         'vocab': vocab
            #     }
            #     for k, v in stats.items():
            #         checkpoint[k] = v
            #     print('Saving checkpoint to %s' % args.checkpoint_path)
            #     torch.save(checkpoint, args.checkpoint_path)
            #     del checkpoint['program_generator_state']
            #     del checkpoint['execution_engine_state']
            #     del checkpoint['baseline_state']
            #     with open(args.checkpoint_path + '.json', 'w') as f:
            #         json.dump(checkpoint, f)

        if t == args.num_iterations:
            break
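
Example #9's gradient step trips over a common TF2 pitfall: tf.keras optimizers take a single iterable of (gradient, variable) pairs, not gradients and variables as separate arguments. A self-contained illustration of the correct call:

import tensorflow as tf

opt = tf.keras.optimizers.Adam(1e-3)
w = tf.Variable(1.0)
with tf.GradientTape() as tape:
    loss = (w - 3.0) ** 2
grads = tape.gradient(loss, [w])
opt.apply_gradients(zip(grads, [w]))  # pairs of (grad, var), never zip(grads), vars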
Example #10
def train_loop(args, train_loader, train_len, val_loader, val_len):
  vocab = utils.load_vocab(args.vocab_json)
  program_generator, pg_kwargs, pg_optimizer = None, None, None
  execution_engine, ee_kwargs, ee_optimizer = None, None, None
  baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
  baseline_type = None

  best_pg_state, best_ee_state, best_baseline_state = None, None, None

  # Set up model
  if args.model_type == 'PG' or args.model_type == 'PG+EE':
    program_generator, pg_kwargs = get_program_generator(args)
    pg_optimizer = torch.optim.Adam(program_generator.parameters(),
                                    lr=args.learning_rate)
    print('Here is the program generator:')
    print(program_generator)
  if args.model_type == 'EE' or args.model_type == 'PG+EE':
    execution_engine, ee_kwargs = get_execution_engine(args)
    ee_optimizer = torch.optim.Adam(execution_engine.parameters(),
                                    lr=args.learning_rate)
    print('Here is the execution engine:')
    print(execution_engine)
  if args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
    baseline_model, baseline_kwargs = get_baseline_model(args)
    params = baseline_model.parameters()
    if args.baseline_train_only_rnn == 1:
      params = baseline_model.rnn.parameters()
    baseline_optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    print('Here is the baseline model')
    print(baseline_model)
    baseline_type = args.model_type
  loss_fn = torch.nn.CrossEntropyLoss().cuda()
  L1loss_fn = torch.nn.L1Loss().cuda()

  stats = {
    'train_losses': [], 'train_rewards': [], 'train_losses_ts': [],
    'train_accs': [], 'val_accs': [], 'val_accs_ts': [],
    'best_val_acc': -1, 'model_t': 0,
  }
  t, epoch, reward_moving_average = 0, 0, 0

  set_mode('train', [program_generator, execution_engine, baseline_model])

  print('train_loader has %d samples' % train_len)
  print('val_loader has %d samples' % val_len)

  tic_time = time.time()
  toc_time = time.time()
  while t < args.num_iterations:
    epoch += 1
    print('Starting epoch %d' % epoch)
    train_loader.reset()
    val_loader.reset()

    cum_I=0 ; cum_U=0


    for batch in train_loader:
      t += 1
      refexps, _, feats, answers, programs, __, image_id = batch

      refexps_var = Variable(refexps.cuda())
      feats_var = Variable(feats.cuda())
      answers_var = Variable(answers.cuda())
      if len(answers_var.shape) == 3:
        answers_var = answers_var.view(answers_var.shape[0], 1, answers_var.shape[1], answers_var.shape[2])
      if programs[0] is not None:
        programs_var = Variable(programs.cuda())

      reward = None
      

      if args.model_type == 'PG':
        # Train program generator with ground-truth programs
        pg_optimizer.zero_grad()
        loss = program_generator(refexps_var, programs_var)
        loss.backward()
        pg_optimizer.step()
      elif args.model_type == 'EE':
        # Train execution engine with ground-truth programs
        ee_optimizer.zero_grad()
        scores = execution_engine(feats_var, programs_var)
        preds = scores.clone()

        scores = scores.transpose(1,2).transpose(2,3).contiguous()
        scores = scores.view([-1,2]).cuda()
        _ans = answers_var.view([-1]).cuda()
        loss = loss_fn(scores, _ans)
        loss.backward()
        ee_optimizer.step()

        def compute_mask_IU(masks, target):
          assert(target.shape[-2:] == masks.shape[-2:])
          masks = masks.data.cpu().numpy()
          masks = masks[:, 1, :, :] > masks[:, 0, :, :]
          masks = masks.reshape([args.batch_size, 320, 320])
          target = target.data.cpu().numpy()
          print('np.sum(masks)={}'.format(np.sum(masks)))
          print('np.sum(target)={}'.format(np.sum(target)))
          I = np.sum(np.logical_and(masks, target))
          U = np.sum(np.logical_or(masks, target))
          return I, U

        I, U = compute_mask_IU(preds, answers)
        now_iou = I*1.0/U
        cum_I += I; cum_U += U
        cum_iou = cum_I*1.0/cum_U

        print_each = 10
        if t % print_each == 0:
          msg = 'now IoU = %f' % (now_iou)
          print(msg)
          msg = 'cumulative IoU = %f' % (cum_iou)
          print(msg)
        if t % print_each == 0:
          cur_time = time.time()
          since_last_print =  cur_time - toc_time
          toc_time = cur_time
          ellapsedtime = toc_time - tic_time
          iter_avr = since_last_print / (print_each+1e-5)
          batch_size = args.batch_size
          case_per_sec = print_each * 1 * batch_size / (since_last_print + 1e-6)
          estimatedleft = (args.num_iterations - t) * 1.0 * iter_avr
          print('ellapsedtime = %d, iter_avr = %f, case_per_sec = %f, estimatedleft = %f'
                % (ellapsedtime, iter_avr, case_per_sec, estimatedleft))

      elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
        baseline_optimizer.zero_grad()
        baseline_model.zero_grad()
        scores = baseline_model(refexps_var, feats_var)
        loss = loss_fn(scores, answers_var)
        loss.backward()
        baseline_optimizer.step()
      elif args.model_type == 'PG+EE':
        programs_pred = program_generator.reinforce_sample(refexps_var)
        programs_pred = programs_pred.data.cpu().numpy()
        programs_pred = torch.LongTensor(programs_pred).cuda()

        scores = execution_engine(feats_var, programs_pred)

        preds = scores.clone()

        scores = scores.transpose(1,2).transpose(2,3).contiguous()
        scores = scores.view([-1,2]).cuda()
        _ans = answers_var.view([-1]).cuda()
        loss = loss_fn(scores, _ans)

        def compute_mask_IU(masks, target):
          assert(target.shape[-2:] == masks.shape[-2:])
          masks = masks.data.cpu().numpy()
          masks = masks[:, 1, :, :] > masks[:, 0, :, :]
          masks = masks.reshape([args.batch_size, 320, 320])
          target = target.data.cpu().numpy()
          print('np.sum(masks)={}'.format(np.sum(masks)))
          print('np.sum(target)={}'.format(np.sum(target)))
          I = np.sum(np.logical_and(masks, target))
          U = np.sum(np.logical_or(masks, target))
          return I, U

        I, U = compute_mask_IU(preds, answers)
        now_iou = I*1.0/U
        cum_I += I; cum_U += U
        cum_iou = cum_I*1.0/cum_U

        print_each = 10
        if t % print_each == 0:
          msg = 'now IoU = %f' % (now_iou); print(msg)
          msg = 'cumulative IoU = %f' % (cum_iou); print(msg)
        if t % print_each == 0:
          cur_time = time.time()
          since_last_print =  cur_time - toc_time
          toc_time = cur_time
          ellapsedtime = toc_time - tic_time
          iter_avr = since_last_print / (print_each+1e-5)
          batch_size = args.batch_size
          case_per_sec = print_each * 1 * batch_size / (since_last_print + 1e-6)
          estimatedleft = (args.num_iterations - t) * 1.0 * iter_avr
          print('ellapsedtime = %d, iter_avr = %f, case_per_sec = %f, estimatedleft = %f'
                % (ellapsedtime, iter_avr, case_per_sec, estimatedleft))

        def easy_compute_mask_IU(masks, target):
          assert(target.shape[-2:] == masks.shape[-2:])
          masks = masks.data.cpu().numpy()
          masks = masks[1, :, :] > masks[0, :, :]
          masks = masks.reshape([320, 320])
          target = target.data.cpu().numpy()
          assert(target.shape == masks.shape)
          I = np.sum(np.logical_and(masks, target))
          U = np.sum(np.logical_or(masks, target))
          return I, U

        now_ious = []
        for _pred, _answer in zip(preds, answers):
          _I, _U = easy_compute_mask_IU(_pred, _answer)
          if _U > 0:
            now_ious.append(_I*1.0/_U)
          else:
            now_ious.append(0.0)

        raw_reward = torch.FloatTensor(now_ious)
        reward_moving_average *= args.reward_decay
        reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean()
        centered_reward = raw_reward - reward_moving_average

        if args.train_execution_engine == 1:
          ee_optimizer.zero_grad()
          loss.backward()
          ee_optimizer.step()

        if args.train_program_generator == 1:
          pg_optimizer.zero_grad()
          program_generator.reinforce_backward(centered_reward.cuda())
          pg_optimizer.step()

      if t % args.record_loss_every == 0:
        print(t, loss.data[0])
        stats['train_losses'].append(loss.data[0])
        stats['train_losses_ts'].append(t)
        if reward is not None:
          stats['train_rewards'].append(reward)

      if t % args.checkpoint_every == 0:
        print('Checking training accuracy ... ')
        if args.model_type == 'PG':
          train_acc = check_accuracy(args, program_generator, execution_engine,
                                     baseline_model, train_loader)
        else:
          train_acc = 0.0

        print('train accuracy is', train_acc)
        print('Checking validation accuracy ...')
        if args.model_type == 'PG':
          val_acc = check_accuracy(args, program_generator, execution_engine,
                                 baseline_model, val_loader)
        else:
          val_acc = 0.0

        print('val accuracy is ', val_acc)

        stats['train_accs'].append(train_acc)
        stats['val_accs'].append(val_acc)
        stats['val_accs_ts'].append(t)

        # Always save the latest model (no best-only gating in this variant)
        if True:
          stats['best_val_acc'] = val_acc
          stats['model_t'] = t
          best_pg_state = get_state(program_generator)
          best_ee_state = get_state(execution_engine)
          best_baseline_state = get_state(baseline_model)

        checkpoint = {
          'args': args.__dict__,
          'program_generator_kwargs': pg_kwargs,
          'program_generator_state': best_pg_state,
          'execution_engine_kwargs': ee_kwargs,
          'execution_engine_state': best_ee_state,
          'baseline_kwargs': baseline_kwargs,
          'baseline_state': best_baseline_state,
          'baseline_type': baseline_type,
          'vocab': vocab
        }
        for k, v in stats.items():
          checkpoint[k] = v
        print('Saving checkpoint to %s' % (args.checkpoint_path + '_' + str(t)))
        torch.save(checkpoint, args.checkpoint_path + '_' + str(t))
        del checkpoint['program_generator_state']
        del checkpoint['execution_engine_state']
        del checkpoint['baseline_state']

      if t == args.num_iterations:
        break
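
Example #10 defines compute_mask_IU twice inside the training loop; the core computation, factored into a standalone helper with a guard for empty unions, looks like this:

import numpy as np

def mask_iou(pred_mask, target_mask):
    # boolean masks of identical shape; IoU = intersection / union
    inter = np.logical_and(pred_mask, target_mask).sum()
    union = np.logical_or(pred_mask, target_mask).sum()
    return float(inter) / union if union > 0 else 0.0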
Example #11
def main(args):
  if args.randomize_checkpoint_path == 1:
    name, ext = os.path.splitext(args.checkpoint_path)
    num = random.randint(1, 1000000)
    args.checkpoint_path = '%s_%06d%s' % (name, num, ext)

  vocab = utils.load_vocab(args.vocab_json)

  if args.use_local_copies == 1:
    shutil.copy(args.train_refexp_h5, '/tmp/train_refexps.h5')
    shutil.copy(args.train_features_h5, '/tmp/train_features.h5')
    shutil.copy(args.val_refexp_h5, '/tmp/val_refexps.h5')
    shutil.copy(args.val_features_h5, '/tmp/val_features.h5')
    args.train_refexp_h5 = '/tmp/train_refexps.h5'
    args.train_features_h5 = '/tmp/train_features.h5'
    args.val_refexp_h5 = '/tmp/val_refexps.h5'
    args.val_features_h5 = '/tmp/val_features.h5'

  refexp_families = None
  if args.family_split_file is not None:
    with open(args.family_split_file, 'r') as f:
      refexp_families = json.load(f)

  train_loader_kwargs = {
    'refexp_h5': args.train_refexp_h5,
    'feature_h5': args.train_features_h5,
    'vocab': vocab,
    'batch_size': args.batch_size,
    'shuffle': args.shuffle_train_data == 1,
    'refexp_families': refexp_families,
    'max_samples': args.num_train_samples,
    'num_workers': args.loader_num_workers,
  }
  val_loader_kwargs = {
    'refexp_h5': args.val_refexp_h5,
    'feature_h5': args.val_features_h5,
    'vocab': vocab,
    'batch_size': args.batch_size,
    'refexp_families': refexp_families,
    'max_samples': args.num_val_samples,
    'num_workers': args.loader_num_workers,
  }

  class TLoader:  
    def __init__(self, kwargs, batch_size):
      import copy
      self.kwargs = copy.deepcopy(kwargs)
      self.batch_size = batch_size 
      #self.reset()

    def reset(self):
      import copy
      self.loader = self.get_loader(copy.deepcopy(self.kwargs), self.batch_size)

    def __iter__(self):
      return self.loader

    def __next__(self):
      # not used directly; iteration goes through the generator from __iter__
      raise NotImplementedError
    
    def get_dataset(self, kwargs):
      if 'refexp_h5' not in kwargs:
        raise ValueError('Must give refexp_h5')
      if 'feature_h5' not in kwargs:
        raise ValueError('Must give feature_h5')
      if 'vocab' not in kwargs:
        raise ValueError('Must give vocab')

      feature_h5_path = kwargs.pop('feature_h5')
      print('Reading features from ', feature_h5_path)
      _feature_h5 = h5py.File(feature_h5_path, 'r')

      _image_h5 = None
      if 'image_h5' in kwargs:
        image_h5_path = kwargs.pop('image_h5')
        print('Reading images from ', image_h5_path)
        _image_h5 = h5py.File(image_h5_path, 'r')

      vocab = kwargs.pop('vocab')
      mode = kwargs.pop('mode', 'prefix')

      refexp_families = kwargs.pop('refexp_families', None)
      max_samples = kwargs.pop('max_samples', None)
      refexp_h5_path = kwargs.pop('refexp_h5')
      image_idx_start_from = kwargs.pop('image_idx_start_from', None)
      print('Reading refexps from ', refexp_h5_path)
      _dataset = ClevrDataset(refexp_h5_path, _feature_h5, vocab, mode,
                              image_h5=_image_h5,
                              max_samples=max_samples,
                              refexp_families=refexp_families,
                              image_idx_start_from=image_idx_start_from)
      return _dataset 


    def get_loader(self, _loader_kwargs, batch_size):
      import copy
      _batch_lis = []
      cur_dataset = self.get_dataset(copy.deepcopy(_loader_kwargs))
      for item in cur_dataset:
        _batch_lis.append(item)
        if len(_batch_lis) == batch_size:
          yield clevr_collate(_batch_lis)
          _batch_lis = []
      # flush the final partial batch; raising StopIteration inside a generator
      # is a RuntimeError under PEP 479, so we simply fall off the end instead
      if _batch_lis:
        yield clevr_collate(_batch_lis)
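
  # Note: get_loader is a generator, so a single pass exhausts it. This is why
  # the train_loop in Example #10 calls train_loader.reset() and
  # val_loader.reset() at the top of every epoch before iterating.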

  train_loader = TLoader(train_loader_kwargs, args.batch_size)
  train_len    = len(train_loader.get_dataset(train_loader_kwargs))
  val_loader = TLoader(val_loader_kwargs, args.batch_size)
  val_len      = len(val_loader.get_dataset(val_loader_kwargs))

  train_loop(args, train_loader, train_len, val_loader, val_len)

  if args.use_local_copies == 1 and args.cleanup_local_copies == 1:
    os.remove('/tmp/train_refexps.h5')
    os.remove('/tmp/train_features.h5')
    os.remove('/tmp/val_refexps.h5')
    os.remove('/tmp/val_features.h5')
Example #12
def main(args):
    global AVAILABLE_OBJECTS, AVAILABLE_MATERIALS, AVAILABLE_SIZES, AVAILABLE_COLOURS
    global NUM_AVAILABLE_OBJECTS, NUM_AVAILABLE_MATERIALS, NUM_AVAILABLE_SIZES, NUM_AVAILABLE_COLOURS
    global obj_probs, material_probs, colour_probs, size_probs
    global save_directory
    model = None

    try:
        with open(args.properties_json, 'r') as f:
            properties = json.load(f)
            for name, rgb in properties['colors'].items():
                rgba = [float(c) / 255.0 for c in rgb] + [1.0]
                AVAILABLE_COLOURS.append((name, rgba))
            AVAILABLE_MATERIALS = [(v, k)
                                   for k, v in properties['materials'].items()]
            AVAILABLE_OBJECTS = [(v, k)
                                 for k, v in properties['shapes'].items()]
            AVAILABLE_SIZES = list(properties['sizes'].items())

            NUM_AVAILABLE_OBJECTS = len(AVAILABLE_OBJECTS)
            NUM_AVAILABLE_MATERIALS = len(AVAILABLE_MATERIALS)
            NUM_AVAILABLE_SIZES = len(AVAILABLE_SIZES)
            NUM_AVAILABLE_COLOURS = len(AVAILABLE_COLOURS)

            # categorical probabilities
            obj_probs = torch.ones(
                NUM_AVAILABLE_OBJECTS) / NUM_AVAILABLE_OBJECTS
            material_probs = torch.ones(
                NUM_AVAILABLE_MATERIALS) / NUM_AVAILABLE_MATERIALS
            colour_probs = torch.ones(
                NUM_AVAILABLE_COLOURS) / NUM_AVAILABLE_COLOURS
            size_probs = torch.ones(NUM_AVAILABLE_SIZES) / NUM_AVAILABLE_SIZES
    except Exception:
        print("Unable to open properties file (properties_json argument)")
        exit()

    # OOD extrapolation: add object (out of training set)
    if args.out_of_distribution == 1:
        AVAILABLE_OBJECTS.append(('Cone', 'cone'))
        NUM_AVAILABLE_OBJECTS += 1
        obj_probs = torch.ones(NUM_AVAILABLE_OBJECTS) / NUM_AVAILABLE_OBJECTS
    elif args.out_of_distribution == 2:
        AVAILABLE_OBJECTS.append(('Corgi', 'corgi'))
        NUM_AVAILABLE_OBJECTS += 1
        obj_probs = torch.ones(NUM_AVAILABLE_OBJECTS) / NUM_AVAILABLE_OBJECTS
    # adversarial or OOD extrapolation: remove object
    if args.remove_object_type is not None:
        NEW_AVAILABLE_OBJECTS = []
        for i in range(len(AVAILABLE_OBJECTS)):
            _, object_name = AVAILABLE_OBJECTS[i]
            if object_name != args.remove_object_type:
                NEW_AVAILABLE_OBJECTS.append(AVAILABLE_OBJECTS[i])
        AVAILABLE_OBJECTS = NEW_AVAILABLE_OBJECTS
        NUM_AVAILABLE_OBJECTS = len(AVAILABLE_OBJECTS)
        obj_probs = torch.ones(NUM_AVAILABLE_OBJECTS) / NUM_AVAILABLE_OBJECTS

    if args.save_dir is not None:
        save_directory = args.save_dir
    if args.baseline_model is not None:
        print('Loading baseline model from ', args.baseline_model)
        model, _ = utils.load_baseline(args.baseline_model)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            model.rnn.expand_vocab(new_vocab['question_token_to_idx'])
    elif args.program_generator is not None and args.execution_engine is not None:
        print('Loading program generator from ', args.program_generator)
        program_generator, _ = utils.load_program_generator(
            args.program_generator)
        print('Loading execution engine from ', args.execution_engine)
        execution_engine, _ = utils.load_execution_engine(
            args.execution_engine, verbose=False)
        if args.vocab_json is not None:
            new_vocab = utils.load_vocab(args.vocab_json)
            program_generator.expand_encoder_vocab(
                new_vocab['question_token_to_idx'])
        model = (program_generator, execution_engine)
    else:
        print(
            'Must give either --baseline_model or --program_generator and --execution_engine'
        )
        return

    print("Calling inference!")
    random_latent = generate_random_latent(num_objects=args.num_objects)
    print(random_latent)
    if args.prob_test == 0:
        print("Running Metropolis Hastings (one constraint)")
        metropolis_hastings(initial_proposal=random_latent,
                            num_iters=int(args.num_iters),
                            std=0.05,
                            args=args,
                            model=model,
                            target_class=args.class_a,
                            num_objects=args.num_objects,
                            output_csv=args.output_csv,
                            test_name=args.test_name)
    if args.prob_test == 1:
        print("Running Metropolis Hastings (two constraints)")
        target_classes = [args.class_a, args.class_b]
        metropolis_hastings_two_classes(initial_proposal=random_latent,
                                        num_iters=int(args.num_iters),
                                        std=0.05,
                                        args=args,
                                        model=model,
                                        target_classes=target_classes,
                                        num_objects=args.num_objects)
    if args.prob_test == 2:
        print("Running Rejection Sampling (one constraint)")
        rejection_sampling(initial_proposal=random_latent,
                           num_iters=int(args.num_iters),
                           args=args,
                           model=model,
                           target_class=args.class_a,
                           num_objects=args.num_objects,
                           output_csv=args.output_csv,
                           test_name=args.test_name)
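
metropolis_hastings and rejection_sampling are not shown in this listing. For orientation only, a generic Metropolis-Hastings loop with a symmetric Gaussian proposal (purely illustrative; the real functions also thread through args, the model, and CSV output):

import numpy as np

def mh_sketch(x0, log_prob, num_iters, std=0.05):
    # symmetric proposal, so the acceptance ratio needs no proposal correction
    x, samples = np.asarray(x0, dtype=float), []
    for _ in range(num_iters):
        x_new = x + np.random.normal(0.0, std, size=x.shape)
        if np.log(np.random.rand()) < log_prob(x_new) - log_prob(x):
            x = x_new  # accept; otherwise keep the current state
        samples.append(x.copy())
    return samples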
Example #13
def train_loop(args, train_loader, val_loader):
    vocab = utils.load_vocab(args.vocab_json)
    program_generator, pg_kwargs, pg_optimizer = None, None, None
    execution_engine, ee_kwargs, ee_optimizer = None, None, None
    baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
    baseline_type = None

    pg_best_state, ee_best_state, baseline_best_state = None, None, None

    # pg_checkpoint = ModelCheckpoint(args.checkpoint_path,
    #                              monitor='val_accuracy',
    #                              verbose=1,
    #                              save_best_only=True,
    #                              mode='min',
    #                              load_weights_on_restart=True)
    # ee_checkpoint = ModelCheckpoint(args.checkpoint_path,
    #                                 monitor='val_accuracy',
    #                                 verbose=1,
    #                                 save_best_only=True,
    #                                 mode='min',
    #                                 load_weights_on_restart=True)
    pg_checkpoint_dir = './pg_training_checkpoints'
    pg_checkpoint_prefix = os.path.join(pg_checkpoint_dir, "ckpt")
    ee_checkpoint_dir = './ee_training_checkpoints'
    ee_checkpoint_prefix = os.path.join(ee_checkpoint_dir, "ckpt")
    # Set up model
    if args.model_type == 'PG' or args.model_type == 'PG+EE':
        program_generator, pg_kwargs = get_program_generator(args)
        pg_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the program generator:')
        checkpoint = tf.train.Checkpoint(optimizer=pg_optimizer,
                                         program_generator=program_generator)
        # program_generator.build(input_shape=[46,])
        # program_generator.compile(optimizer='adam', loss='mse')
        print(program_generator)
    if args.model_type == 'EE' or args.model_type == 'PG+EE':
        execution_engine, ee_kwargs = get_execution_engine(args)
        ee_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the execution engine:')
        print(execution_engine)
        checkpoint = tf.train.Checkpoint(optimizer=ee_optimizer,
                                         execution_engine=execution_engine)

    stats = {
        'train_losses': [],
        'train_rewards': [],
        'train_losses_ts': [],
        'train_accs': [],
        'val_accs': [],
        'val_accs_ts': [],
        'best_val_acc': -1,
        'model_t': 0,
    }
    t, epoch, reward_moving_average = 0, 0, 0
    batch_size = 64  # note: hardcoded; args.batch_size is not used in this loop

    # set_mode('train', [program_generator, execution_engine, baseline_model])

    print('train_loader has %d samples' % len(train_loader))
    print('val_loader has %d samples' % len(val_loader))
    train_data_load = batch_creater(train_loader, batch_size, False)
    val_data_load = batch_creater(val_loader, batch_size, False)
    print("train data load length :", len(train_data_load))

    while t < args.num_iterations:
        total_loss = 0
        epoch += 1
        print('Starting epoch %d' % epoch)
        #print("value of t :", t)
        for run_num, batch in enumerate(train_data_load):
            batch_loss = 0
            t += 1
            questions, _, feats, answers, programs, _ = to_tensor(
                batch[0]), batch[1], to_tensor(batch[2]), to_tensor(
                    batch[3]), to_tensor(batch[4]), batch[5]

            questions_var = tf.Variable(questions)
            feats_var = tf.Variable(feats, trainable=True)
            answers_var = tf.Variable(answers, trainable=True)
            if programs[0] is not None:
                programs_var = tf.Variable(programs, trainable=True)

            reward = None
            if args.model_type == 'PG':
                # Train program generator with ground-truth programs
                with tf.GradientTape() as tape:
                    batch_loss = program_generator(questions_var, programs_var)
                total_loss += batch_loss
                variables = program_generator.variables
                gradients = tape.gradient(batch_loss, variables)
                pg_optimizer.apply_gradients(zip(gradients, variables))

            elif args.model_type == 'EE':
                # Train execution engine with ground-truth programs
                feats = tf.transpose(feats_var, perm=[0, 2, 3, 1])
                feats_var = tf.Variable(feats)
                with tf.GradientTape() as tape:
                    scores = execution_engine(feats_var, programs_var)
                    answers_var = tf.dtypes.cast(answers_var, dtype=tf.int32)
                    batch_loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=scores, labels=answers_var))
                total_loss += batch_loss
                grads = tape.gradient(batch_loss,
                                      execution_engine.trainable_variables)
                gradients = [
                    grad if grad is not None else tf.zeros_like(var) for var,
                    grad in zip(execution_engine.trainable_variables, grads)
                ]
                #TODO Might need some changes for gradients

                ee_optimizer.apply_gradients(
                    zip(gradients, execution_engine.trainable_variables))

            elif args.model_type == 'PG+EE':
                print("in PG EE -----------------")
                feats = tf.transpose(feats_var, perm=[0, 2, 3, 1])
                feats_var = tf.Variable(feats)
                with tf.GradientTape() as pg_tape, tf.GradientTape(
                ) as ee_tape:
                    programs_pred = program_generator.reinforce_sample(
                        questions_var)
                    print("shape of programs_pred : ", programs_pred.shape)
                    scores = execution_engine(feats_var, programs_pred)
                    answers_var = tf.dtypes.cast(answers_var, dtype=tf.int32)
                    batch_loss = tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=scores, labels=answers_var))
                    # torch's scores.max(1) returns (values, indices); the
                    # prediction is the argmax index, not the max value
                    print("dim of score :", scores.shape)
                    preds = tf.math.argmax(scores, axis=1)
                    print("dim of pred :", preds.shape)
                    raw_reward = tf.cast(tf.equal(preds, tf.cast(answers, tf.int64)),
                                         dtype=tf.float32)
                    reward_moving_average *= args.reward_decay
                    reward_moving_average += (
                        1.0 - args.reward_decay) * raw_reward.numpy().mean()
                    centered_reward = raw_reward - reward_moving_average

                if args.train_execution_engine == 1:
                    grads = ee_tape.gradient(
                        batch_loss, execution_engine.trainable_variables)
                    gradients = [
                        grad if grad is not None else tf.zeros_like(var)
                        for var, grad in zip(
                            execution_engine.trainable_variables, grads)
                    ]
                    # TODO Might need some changes for gradients

                    ee_optimizer.apply_gradients(
                        zip(gradients, execution_engine.trainable_variables))

                if args.train_program_generator == 1:
                    loss, multinomial_outputs = program_generator.reinforce_backward(
                        centered_reward)
                    multinomial_outputs = tf.concat(multinomial_outputs, 0)
                    multinomial_outputs = tf.Variable(multinomial_outputs)
                    print("multi op shape new : ", multinomial_outputs.shape)
                    grads = pg_tape.gradient(loss, multinomial_outputs)
                    # apply_gradients expects (grad, var) pairs; note these are the
                    # sampled outputs rather than program_generator.trainable_variables,
                    # which the TODO above suggests still needs rework
                    pg_optimizer.apply_gradients([(grads, multinomial_outputs)])
            # float() handles both tf tensors and the untouched int 0
            print('Epoch {} Batch No. {} Loss {:.4f}'.format(
                epoch, run_num, float(batch_loss)))

            # if t == args.num_iterations:
            #     break

            if t % (args.record_loss_every * 2) == 0:
                #print(t, batch_loss)
                stats['train_losses'].append(batch_loss)
                stats['train_losses_ts'].append(t)
                if reward is not None:
                    stats['train_rewards'].append(reward)

            if t % args.checkpoint_every == 0:
                print('Checking training accuracy ... ')
                train_acc = check_accuracy(args, program_generator,
                                           execution_engine, train_data_load)
                print('train accuracy is', train_acc)
                print('Checking validation accuracy ...')
                val_acc = check_accuracy(args, program_generator,
                                         execution_engine, val_data_load)
                print('val accuracy is ', val_acc)
                stats['train_accs'].append(train_acc)
                stats['val_accs'].append(val_acc)
                stats['val_accs_ts'].append(t)

                if val_acc > stats['best_val_acc']:
                    stats['best_val_acc'] = val_acc
                    stats['model_t'] = t
                    if args.model_type == 'PG':
                        checkpoint = tf.train.Checkpoint(
                            optimizer=pg_optimizer, model=program_generator)
                        checkpoint.save(file_prefix=pg_checkpoint_prefix)
                    if args.model_type == 'EE':
                        checkpoint = tf.train.Checkpoint(
                            optimizer=ee_optimizer, model=execution_engine)
                        checkpoint.save(file_prefix=ee_checkpoint_prefix)

        if t == args.num_iterations:
            break
def train_loop(args, train_loader, val_loader):
  vocab = utils.load_vocab(args.vocab_json)
  program_generator, pg_kwargs, pg_optimizer = None, None, None
  execution_engine, ee_kwargs, ee_optimizer = None, None, None
  baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
  baseline_type = None

  best_pg_state, best_ee_state, best_baseline_state = None, None, None

  # Set up model; the device is needed by every branch of the training loop
  # below, so resolve it once up front.
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  if args.model_type in ['PG', 'PG+EE', 'PG+EE+GQNT']:
    program_generator, pg_kwargs = get_program_generator(args)
    pg_optimizer = torch.optim.Adam(program_generator.parameters(),
                                    lr=args.learning_rate)
    print('Here is the program generator:')
    print(program_generator)
  if args.model_type in ['EE', 'PG+EE', 'PG+EE+GQNT']:
    execution_engine, ee_kwargs = get_execution_engine(args)
    if args.multi_gpu:
      execution_engine = torch.nn.DataParallel(execution_engine)
    elif device.type != 'cpu':
      execution_engine = execution_engine.cuda()

    ee_optimizer = torch.optim.Adam(execution_engine.parameters(),
                                    lr=args.learning_rate)
    print('Here is the execution engine:')
    print(execution_engine)
  if args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
    baseline_model, baseline_kwargs = get_baseline_model(args)
    params = baseline_model.parameters()
    if args.baseline_train_only_rnn == 1:
      params = baseline_model.rnn.parameters()
    baseline_optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    print('Here is the baseline model')
    print(baseline_model)
    baseline_type = args.model_type
  if args.model_type == 'PG+EE+GQNT':
    kwargs = {
      'n_channels': 3,
      'v_dim': 224
    }
    model = MMModel(**kwargs)
    params = model.parameters()
    baseline_optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    print('Here is the GQNT model')
    print(model)

  # loss_fn is referenced by the LSTM/CNN and PG+EE branches below but was
  # never defined; plain cross-entropy matches its usage on answer scores.
  loss_fn = torch.nn.CrossEntropyLoss().to(device)
  loss_fn_1 = torch.nn.BCEWithLogitsLoss().to(device)
  loss_fn_2 = torch.nn.MSELoss()

  stats = {
    'train_losses': [], 'train_rewards': [], 'train_losses_ts': [],
    'train_accs': [], 'val_accs': [], 'val_accs_ts': [],
    'best_val_acc': -1, 'model_t': 0,
  }
  t, epoch, reward_moving_average = 0, 0, 0

  set_mode('train', [program_generator, execution_engine, baseline_model])

  print('train_loader has %d samples' % len(train_loader.dataset))
  print('val_loader has %d samples' % len(val_loader.dataset))

  while t < args.num_iterations:
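    # Training runs for a fixed number of iterations t rather than epochs;
    # the epoch counter is only used for logging.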
    epoch += 1
    print('Starting epoch %d' % epoch)
    start = time.time()
    for batch in train_loader:
      # print("data loader: " + str(time.time() - start))
      start_batch = time.time()

      t += 1
      questions, images, feats, answers, programs, _, ocr_tokens = batch
      # print("mean answer value" + str((answers.sum() / float(len(answers))).item()))
      questions_var = Variable(questions.to(device))
      # feats and answers are consumed by every model branch below (the
      # PG+EE and baseline branches previously hit a NameError), so move
      # them to the device up front.
      feats_var = Variable(feats.to(device))
      answers_var = Variable(answers.to(device))
      if programs[0] is not None:
        programs_var = Variable(programs.to(device))

      reward = None
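      # Each supported model type gets its own training branch; only the
      # REINFORCE-based PG+EE variants compute a reward signal.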
      if args.model_type == 'PG':
        # Train program generator with ground-truth programs
        pg_optimizer.zero_grad()
        loss = program_generator(questions_var, programs_var)
        loss.backward()
        pg_optimizer.step()
      elif args.model_type == 'EE':
        # Train execution engine with ground-truth programs
        start = time.time()

        ee_optimizer.zero_grad()
        if images[0] is not None:
          images_var = Variable(images.to(device))
        else:
          feats_var = Variable(feats.to(device))
        answers_var = Variable(answers.to(device))
        text_embs = process_tokens(ocr_tokens).to(device)
        # text_embs = torch.cat([chars, locs], dim=1).reshape(-1).to(device)


        # print("OCR loading / processing + put stuff on cuda: " + str(time.time() - start))
        start = time.time()
        scores = execution_engine(feats_var, programs_var, text_embs)

        # print("Total Resnet + BiLSTM: " + str(time.time() - start))
        start = time.time()
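        # Two-part supervision for the OCR head: BCE over the 84-dim
        # character logits (3 tokens x 28 dims, apparently 26 character
        # classes plus 2 location values each) and MSE on the locations.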
        char_loss = loss_fn_1(scores, text_embs.reshape((-1, 84)))
        loc_loss = loss_fn_2(text_embs[:, :, 26:],
                             scores.reshape(-1, 3, 28)[:, :, 26:])
        # torch.sum(char_loss, loc_loss) would treat loc_loss as a dimension
        # argument; the two scalar losses are simply added.
        loss = char_loss + loc_loss

        loss.backward()
        ee_optimizer.step()
        # print("Optimization Step: " + str(time.time() - start))

      elif args.model_type in ['LSTM', 'CNN+LSTM', 'CNN+LSTM+SA']:
        baseline_optimizer.zero_grad()
        baseline_model.zero_grad()
        scores = baseline_model(questions_var, feats_var)
        loss = loss_fn(scores, answers_var)
        loss.backward()
        baseline_optimizer.step()
      elif args.model_type == 'PG+EE':
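        # PG+EE: sample programs with REINFORCE; answer correctness is the
        # raw reward, centered with a moving-average baseline before the
        # policy-gradient update.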
        programs_pred = program_generator.reinforce_sample(questions_var)
        scores = execution_engine(feats_var, programs_pred)

        loss = loss_fn(scores, answers_var)
        _, preds = scores.data.cpu().max(1)
        raw_reward = (preds == answers).float()
        reward_moving_average *= args.reward_decay
        reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean()
        centered_reward = raw_reward - reward_moving_average

        if args.train_execution_engine == 1:
          ee_optimizer.zero_grad()
          loss.backward()
          ee_optimizer.step()

        if args.train_program_generator == 1:
          pg_optimizer.zero_grad()
          program_generator.reinforce_backward(centered_reward.to(device))
          pg_optimizer.step()
      elif args.model_type == 'PG+EE+GQNT':
        programs_pred = program_generator.reinforce_sample(questions_var)
        # Clear gradients on the optimizer that is actually stepped below;
        # the baseline optimizer and MMModel play no part in this loss.
        ee_optimizer.zero_grad()
        execution_engine.zero_grad()
        chars, locs = process_tokens(ocr_tokens)
        text_embs = torch.cat([chars, locs], dim=1).to(device)
        scores = execution_engine(feats_var, programs_pred, text_embs)
        loss = loss_fn(scores, answers_var)
        loss.backward()
        ee_optimizer.step()
      # print("total batch time (without data loading): " + str(time.time() - start_batch))

      if t % args.record_loss_every == 0:
        print(t, loss.item())
        stats['train_losses'].append(loss.item())
        stats['train_losses_ts'].append(t)
        if reward is not None:
          stats['train_rewards'].append(reward)

      if t % args.checkpoint_every == 0:
        print('Checking training accuracy ... ')
        train_acc, train_loc_err = check_accuracy(args, program_generator, execution_engine,
                                   baseline_model, train_loader)
        print('train accuracy is', train_acc)
        print('train loc error is', train_loc_err)
        print('Checking validation accuracy ...')
        val_acc, val_loc_err = check_accuracy(args, program_generator, execution_engine,
                                 baseline_model, val_loader)
        print('val accuracy is ', val_acc)
        print('val loc error is', val_loc_err)

        stats['train_accs'].append(train_acc)
        stats['val_accs'].append(val_acc)
        stats['val_accs_ts'].append(t)

        if val_acc > stats['best_val_acc']:
          stats['best_val_acc'] = val_acc
          stats['model_t'] = t
          best_pg_state = get_state(program_generator)
          best_ee_state = get_state(execution_engine)
          best_baseline_state = get_state(baseline_model)

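        # Two artifacts are written: the full checkpoint (including state
        # dicts) via torch.save, and a JSON sidecar with the heavy state
        # entries stripped so the training stats stay human-readable.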
        checkpoint = {
          'args': args.__dict__,
          'program_generator_kwargs': pg_kwargs,
          'program_generator_state': best_pg_state,
          'execution_engine_kwargs': ee_kwargs,
          'execution_engine_state': best_ee_state,
          'baseline_kwargs': baseline_kwargs,
          'baseline_state': best_baseline_state,
          'baseline_type': baseline_type,
          'vocab': vocab
        }
        for k, v in stats.items():
          checkpoint[k] = v
        print('Saving checkpoint to %s' % args.checkpoint_path)
        torch.save(checkpoint, args.checkpoint_path)
        del checkpoint['program_generator_state']
        del checkpoint['execution_engine_state']
        del checkpoint['baseline_state']
        with open(args.checkpoint_path + '.json', 'w') as f:
          json.dump(checkpoint, f)

      if t == args.num_iterations:
        break