def train_val_augment():
    """Train the listener with the augmented (back-translated) data.

    Builds the tokenizer, loads precomputed image features, creates the
    ground-truth training env plus an augmentation env, prints dataset
    statistics, builds the validation envs, and launches training.
    """
    setup()

    # Create a batch training environment that will also preprocess text.
    vocab = read_vocab(train_vocab)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Load the env img features; the scan id is the prefix of each feature key.
    feat_dict = read_img_features(features)
    featurized_scans = {key.split("_")[0] for key in feat_dict.keys()}

    # Locate the augmentation data (remote sync vs. local path).
    if args.upload:
        aug_path = get_sync_dir(os.path.join(args.upload_path, args.aug))
    else:
        # BUGFIX: this branch assigned to a misspelled name (`aux_path`),
        # so the local (non-upload) path crashed with a NameError when
        # `aug_path` was used below.
        aug_path = os.path.join(args.R2R_Aux_path, args.aug)

    # Create the training environments: ground-truth and augmented.
    train_env = R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    aug_env = R2RBatch(feat_dict,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok,
                       name='aug')

    # Printing out the statistics of the dataset.
    stats = train_env.get_statistics()
    print("The training data_size is : %d" % train_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))
    stats = aug_env.get_statistics()
    print("The augmentation data size is %d" % aug_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))

    # Setup the validation data: one (env, evaluator) pair per split.
    val_envs = {
        split: (R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=[split],
                         tokenizer=tok),
                Evaluation([split], featurized_scans, tok))
        for split in ['train', 'val_seen', 'val_unseen']
    }

    # Start training.
    train(train_env, tok, args.iters, val_envs=val_envs, aug_env=aug_env)
def train(train_env, tok, n_iters, log_every=100, val_envs=None, aug_env=None):
    """Train the listener agent, optionally mixing in back-translated data.

    Args:
        train_env: R2RBatch environment holding the ground-truth train split.
        tok: Tokenizer shared by all environments.
        n_iters: total number of training iterations to run.
        log_every: iterations between tensorboard logging / validation runs.
        val_envs: dict mapping split name -> (env, evaluator); validated and
            logged after every `log_every` interval.
        aug_env: optional R2RBatch with augmented data; when given, training
            alternates between `train_env` and `aug_env`.
    """
    # BUGFIX: `val_envs={}` was a mutable default argument (shared across
    # calls); use the None sentinel instead — callers are unaffected.
    if val_envs is None:
        val_envs = {}

    writer = SummaryWriter(logdir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    # Optional speaker used for back-translation self-training.
    speaker = None
    if args.self_train:
        speaker = Speaker(train_env, listner, tok)
        if args.speaker is not None:
            if args.upload:
                print("Load the speaker from %s." % args.speaker)
                speaker.load(
                    get_sync_dir(os.path.join(args.upload_path, args.speaker)))
            else:
                print("Load the speaker from %s." % args.speaker)
                speaker.load(os.path.join(args.R2R_Aux_path, args.speaker))

    # Resume listener weights (and the iteration counter) from a checkpoint.
    start_iter = 0
    if args.load is not None:
        if args.upload:
            refs_paths = get_outputs_refs_paths()['experiments'][0]
            print(refs_paths)
            load_model = os.path.join(refs_paths, args.load)
            print(load_model)
            print("LOAD THE listener from %s" % load_model)
            start_iter = listner.load(load_model)
        else:
            print("LOAD THE listener from %s" % args.load)
            start_iter = listner.load(
                os.path.join(args.R2R_Aux_path, args.load))

    start = time.time()

    # Best validation accuracy seen so far, per split.
    best_val = {
        'val_seen': {
            "accu": 0.,
            "state": "",
            'update': False
        },
        'val_unseen': {
            "accu": 0.,
            "state": "",
            'update': False
        }
    }
    if args.fast_train:
        log_every = 40

    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, start_iter + n_iters - idx)
        # BUGFIX(style): renamed from `iter`, which shadowed the builtin.
        cur_iter = idx + interval

        # Train for log_every interval.
        if aug_env is None:
            # The default training process (ground-truth data only).
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)  # Train interval iters
        else:
            if args.accumulate_grad:
                # One optimizer step accumulates gradients from a GT batch
                # plus an augmented (back-translation) batch.
                for _ in range(interval // 2):
                    listner.zero_grad()
                    listner.env = train_env  # Train with GT data
                    args.ml_weight = 0.2
                    listner.accumulate_gradient(feedback_method)
                    listner.env = aug_env  # Train with Back Translation
                    args.ml_weight = 0.6  # Sem-Configuration
                    listner.accumulate_gradient(feedback_method,
                                                speaker=speaker)
                    listner.optim_step()
            else:
                for _ in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)
                    # Train with Back Translation
                    listner.env = aug_env
                    args.ml_weight = 0.6
                    listner.train(1, feedback=feedback_method, speaker=speaker)

        # Log the training stats to tensorboard.
        total = max(sum(listner.logs['total']), 1)
        # length_rl == length_ml ? entropy length
        assert (max(len(listner.logs['rl_loss']), 1) ==
                max(len(listner.logs['ml_loss']), 1))
        max_rl_length = max(len(listner.logs['critic_loss']), 1)
        log_length = max(len(listner.logs['rl_loss']), 1)
        rl_loss = sum(listner.logs['rl_loss']) / log_length
        ml_loss = sum(listner.logs['ml_loss']) / log_length
        critic_loss = sum(
            listner.logs['critic_loss']) / log_length  # / length / args.batchSize
        spe_loss = sum(listner.logs['spe_loss']) / log_length
        pro_loss = sum(listner.logs['pro_loss']) / log_length
        mat_loss = sum(listner.logs['mat_loss']) / log_length
        fea_loss = sum(listner.logs['fea_loss']) / log_length
        ang_loss = sum(listner.logs['ang_loss']) / log_length
        entropy = sum(
            listner.logs['entropy']) / log_length  # / length / args.batchSize
        predict_loss = sum(listner.logs['us_loss']) / log_length
        writer.add_scalar("loss/rl_loss", rl_loss, idx)
        writer.add_scalar("loss/ml_loss", ml_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/spe_loss", spe_loss, idx)
        writer.add_scalar("loss/pro_loss", pro_loss, idx)
        writer.add_scalar("loss/mat_loss", mat_loss, idx)
        writer.add_scalar("loss/fea_loss", fea_loss, idx)
        writer.add_scalar("loss/ang_loss", ang_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_rl_length", max_rl_length, idx)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        print("total_actions", total)
        print("max_rl_length", max_rl_length)

        # Run validation.
        loss_str = ""
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env
            # Get validation loss under the same conditions as training.
            iters = None if args.fast_train or env_name != 'train' else 20  # 20 * 64 = 1280
            # Get validation distance from goal under test evaluation conditions.
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += "%s " % env_name
            for metric, val in score_summary.items():
                if metric in ['success_rate']:
                    loss_str += ', %s: %.4f' % (metric, val)
                    writer.add_scalar("%s/accuracy" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['accu']:
                            best_val[env_name]['accu'] = val
                            best_val[env_name]['update'] = True
                if metric in ['spl']:
                    writer.add_scalar("%s/spl" % env_name, val, idx)
                    loss_str += ', %s: %.4f' % (metric, val)
            loss_str += '\n'
        loss_str += '\n'

        # Save a checkpoint whenever a split reaches a new best accuracy.
        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d \n%s' % (cur_iter,
                                                                loss_str)
                best_val[env_name]['update'] = False
                file_dir = os.path.join(output_dir, "snap", args.name,
                                        "state_dict", "best_%s" % (env_name))
                listner.save(idx, file_dir)

        print(('%s (%d %d%%) \n%s' %
               (timeSince(start, float(cur_iter) / n_iters), cur_iter,
                float(cur_iter) / n_iters * 100, loss_str)))

        if cur_iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

        # Periodic snapshot independent of validation performance.
        if cur_iter % args.save_iter == 0:
            file_dir = os.path.join(output_dir, "snap", args.name,
                                    "state_dict", "Iter_%06d" % (cur_iter))
            listner.save(idx, file_dir)
from polyaxon_client.tracking import get_outputs_refs_paths

# When validating a listener on the cluster, resolve the checkpoint path from
# the referenced Polyaxon experiment's outputs.
if args.train == 'validlistener' and args.upload:
    refs_paths = get_outputs_refs_paths()['experiments'][0]
    print(refs_paths)
    load_model = os.path.join(refs_paths, args.load)
    print(load_model)

import warnings
warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter
from polyaxon_client.tracking import get_outputs_path

# Remote (Polyaxon) run: sync every data artifact to the local machine and
# route outputs/logs into the experiment's output directory.
if args.upload:
    train_vocab = get_sync_dir(os.path.join(args.upload_path,
                                            args.TRAIN_VOCAB))
    trainval_vocab = get_sync_dir(
        os.path.join(args.upload_path, args.TRAINVAL_VOCAB))
    features = get_sync_dir(
        os.path.join(args.upload_path, args.IMAGENET_FEATURES))
    output_dir = get_outputs_path()
    # Tensorboard / snapshot directory, created eagerly.
    log_dir = os.path.join(output_dir, "snap", args.name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # Object-level features and bounding boxes used by the auxiliary tasks.
    sparse_obj_feat = get_sync_dir(
        os.path.join(args.upload_path, args.SPARSE_OBJ_FEATURES))
    dense_obj_feat1 = get_sync_dir(
        os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES1))
    dense_obj_feat2 = get_sync_dir(
        os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES2))
    bbox = get_sync_dir(os.path.join(args.upload_path, args.BBOX_FEATURES))
# Default relative resource paths: vocabularies and precomputed CNN features.
TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'
IMAGENET_FEATURES = 'lyx/img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'lyx/img_features/ResNet-152-places365.tsv'
# Object-feature archives are parameterized twice by the detection run id
# (args.objdir): once for the directory, once for the file suffix.
SPARSE_OBJ_FEATURES = 'lyx/obj_features/%s/panorama_objs_Features_nms_%s.npy' % (
    args.objdir, args.objdir)
DENSE_OBJ_FEATURES1 = 'lyx/obj_features/%s/panorama_objs_DenseFeatures_nms1_%s.npy' % (
    args.objdir, args.objdir)
DENSE_OBJ_FEATURES2 = 'lyx/obj_features/%s/panorama_objs_DenseFeatures_nms2_%s.npy' % (
    args.objdir, args.objdir)

# Resolve (and possibly sync/download) the chosen feature files locally.
if args.features == 'imagenet':
    features = get_sync_dir(IMAGENET_FEATURES)
    sparse_obj_feat = get_sync_dir(SPARSE_OBJ_FEATURES)
    dense_obj_feat1 = get_sync_dir(DENSE_OBJ_FEATURES1)
    dense_obj_feat2 = get_sync_dir(DENSE_OBJ_FEATURES2)

# Fast-train mode expects a smaller "-fast" variant of the feature file.
if args.fast_train:
    name, ext = os.path.splitext(features)
    features = name + "-fast" + ext

feedback_method = args.feedback  # teacher or sample

print(args)


# NOTE(review): this definition continues beyond the visible chunk; only its
# header is shown here.
def train_speaker(train_env, tok, n_iters, log_every=500, val_envs={}):
    writer = SummaryWriter(logdir=log_dir)
from polyaxon_client.tracking import get_outputs_refs_paths

# When validating a listener on the cluster, resolve the checkpoint path from
# the referenced Polyaxon experiment's outputs.
if args.train == 'validlistener' and args.upload:
    refs_paths = get_outputs_refs_paths()['experiments'][0]
    print(refs_paths)
    load_model = os.path.join(refs_paths, args.load)
    print(load_model)

import warnings
warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter
from polyaxon_client.tracking import get_outputs_path

# Remote (Polyaxon) run: sync data artifacts locally and write outputs/logs
# under the experiment's output directory; otherwise fall back to local paths.
if args.upload:
    train_vocab = get_sync_dir(os.path.join(args.upload_path,
                                            args.TRAIN_VOCAB))
    trainval_vocab = get_sync_dir(
        os.path.join(args.upload_path, args.TRAINVAL_VOCAB))
    features = get_sync_dir(
        os.path.join(args.upload_path, args.IMAGENET_FEATURES))
    output_dir = get_outputs_path()
    log_dir = os.path.join(output_dir, "snap", args.name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # sparse_obj_feat = get_sync_dir(os.path.join(args.upload_path, args.SPARSE_OBJ_FEATURES))
    # dense_obj_feat1 = get_sync_dir(os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES1))
    # dense_obj_feat2 = get_sync_dir(os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES2))
    # bbox = get_sync_dir(os.path.join(args.upload_path, args.BBOX_FEATURES))
else:
    # NOTE(review): the else branch continues beyond the visible chunk.
    train_vocab = os.path.join(args.R2R_Aux_path, args.TRAIN_VOCAB)