def main(args): log_dir = osp.join(CONFIG.DIRS.LOG_DIR, "debug") # Initialize experiment config, device = experiment_utils.init_experiment( log_dir, CONFIG, args.config_path, transient=True) # Load model, data loaders and trainer. debug = { "sample_sequential": not args.shuffle, "augment": args.augment, "num_workers": 0, } model, optimizer, loaders, trainer, _ = experiment_utils.get_factories( config, device, debug=debug) num_ctx_frames = config.SAMPLING.NUM_CONTEXT_FRAMES num_frames = config.SAMPLING.NUM_FRAMES_PER_SEQUENCE try: loader = loaders["pretrain_train"] for batch_idx, batch in enumerate(loader): if batch_idx > 4: break print(f"Batch #{batch_idx}") frames = batch["frames"] b, _, c, h, w = frames.shape frames = frames.view(b, num_frames, num_ctx_frames, c, h, w) # # To visualize baseline batch. # for i in range(frames.shape[1]): # grid_img = torchvision.utils.make_grid(frames[:, i, -1], nrow=5) # plt.imshow(grid_img.permute(1, 2, 0)) # plt.show() for b in range(frames.shape[0]): print(f"\tBatch Item {b}") if args.viz_context: fig, axes = plt.subplots( num_ctx_frames, num_frames, constrained_layout=True ) for i in range(num_frames): for j in range(num_ctx_frames): axes[j, i].imshow( UnNormalize()(frames[b, i, j]).permute(1, 2, 0) ) for ax in axes.flatten(): ax.axis("off") plt.show() else: grid_img = torchvision.utils.make_grid( frames[b, :, -1], nrow=5) # grid_img = UnNormalize()(grid_img) plt.imshow(grid_img.permute(1, 2, 0)) plt.show() except KeyboardInterrupt: sys.exit()
def setup_embedder(log_dir, config_path):
    """Load the latest embedder checkpoint."""
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')
    config, device = experiment_utils.init_experiment(log_dir, CONFIG, config_path)
    embedder, _, _, _, _ = experiment_utils.get_factories(
        config, device, debug={'augment': False})
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(embedder), checkpoint_dir, device)
    global_step = checkpoint_manager.restore_or_initialize()
    print(f"Restored model from checkpoint {global_step}")
    return embedder
def setup_embedder(log_dir, config_path): """Load the latest embedder checkpoint and dataloaders.""" checkpoint_dir = os.path.join(log_dir, 'checkpoints') config, device = experiment_utils.init_experiment(log_dir, CONFIG, config_path) embedder, _, loaders, _, _ = experiment_utils.get_factories( config, device, debug={'augment': False}) checkpoint_manager = checkpoint.CheckpointManager( checkpoint.Checkpoint(embedder), checkpoint_dir, device) global_step = checkpoint_manager.restore_or_initialize() l2_normalize = config.LOSS.L2_NORMALIZE_EMBEDDINGS print(f"Restored model from checkpoint {global_step}") return embedder, loaders, l2_normalize
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # Initialize experiment.
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
        "ACTION_CLASS",
        [],
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # Load model and data loaders.
    debug = {"sample_sequential": False, "augment": False, "labeled": "both"}
    tcc_model, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug=debug)

    # Load model checkpoint.
    checkpoint.CheckpointManager.load_latest_checkpoint(
        checkpoint.Checkpoint(tcc_model),
        osp.join(config.DIRS.CKPT_DIR,
                 osp.basename(osp.normpath(args.logdir))),
        device,
    )
    tcc_model.to(device)
    tcc_model.featurizer_net.eval()
    tcc_model.encoder_net.train()
    freeze_model(tcc_model.featurizer_net, True, True)
    # for name, param in tcc_model.named_parameters():
    #     if param.requires_grad:
    #         print(name)

    # Create LSTM model.
    lstm = LSTMModel(
        in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE,
        lstm_dim=32,
        num_layers=1,
    )
    lstm.to(device).train()

    optimizer = torch.optim.Adam(
        [
            {"params": lstm.parameters()},
            {"params": tcc_model.parameters(), "lr": 1e-4},
        ],
        lr=1e-3,
    )

    # Figure out the max batch size that's a multiple of the number of
    # context frames. This is so we can support large videos with many frames.
    lcm = tcc_model.num_ctx_frames
    max_batch_size = math.floor(128 / lcm) * lcm

    # Test on the query embodiment.
    rand_corr, rand_num = test_query(
        device,
        lstm,
        tcc_model,
        max_batch_size,
        (
            loaders["downstream_train"]["rms"],
            loaders["downstream_valid"]["rms"],
        ),
        args.l2_normalize,
        args.batch_size,
    )
    rand_acc = rand_corr / rand_num
    print("Randomly initialized LSTM accuracy: {} ({}/{})".format(
        rand_acc, rand_corr, rand_num))

    global_step = 0
    complete = False
    try:
        while not complete:
            accs = []
            for action_name, loader in loaders["downstream_train"].items():
                tcc_model.encoder_net.train()
                lstm.train()
                if action_name != "rms":
                    batch_embs, batch_labels = [], []
                    correct, num_x = 0, 0
                    # set_trace()
                    for batch_idx, batch in enumerate(loader):
                        if batch_idx > 2:
                            break
                        if len(batch_embs) < args.batch_size:
                            frames = batch["frames"]
                            embs = embed(
                                tcc_model,
                                frames,
                                device,
                                max_batch_size,
                                args.l2_normalize,
                            )
                            batch_embs.append(embs[0])
                            batch_labels.extend(batch["success"])
                        if len(batch_embs) == args.batch_size or batch_idx == (
                                len(loader) - 1):
                            print(len(batch_embs))

                            # Sort the list of embeddings by sequence length
                            # in descending order.
                            idxs_sorted = np.argsort(
                                [-len(x) for x in batch_embs])
                            batch_embs = [batch_embs[i] for i in idxs_sorted]
                            batch_labels = torch.stack(
                                [batch_labels[i] for i in idxs_sorted])
                            batch_labels = (
                                batch_labels.unsqueeze(1).float().to(device))

                            # Forward through the LSTM and compute the loss.
                            out = lstm(batch_embs)
                            loss = F.binary_cross_entropy_with_logits(
                                out, batch_labels)

                            # Backprop.
                            optimizer.zero_grad()
                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(
                                lstm.parameters(), 1.0)
                            optimizer.step()

                            # Compute accuracy.
                            pred = torch.sigmoid(out) > 0.5
                            correct += (pred.eq(
                                batch_labels.view_as(pred)).sum().item())
                            num_x += len(pred)

                            logging.info("{}: Loss: {:.6f}".format(
                                global_step, loss.item()))
                            global_step += 1
                            batch_embs, batch_labels = [], []

                            # Exit if complete.
                            if global_step >= args.max_iters:
                                complete = True
                                break
                    if num_x > 0:
                        acc = correct / num_x
                        logging.info("{} accuracy: {:.2f} ({}/{})".format(
                            action_name, acc, correct, num_x))
                        accs.append(acc)
            print("Mean embodiment accuracy: {}".format(np.mean(accs)))

            query_corr, query_num = test_query(
                device,
                lstm,
                tcc_model,
                max_batch_size,
                (
                    loaders["downstream_train"]["rms"],
                    loaders["downstream_valid"]["rms"],
                ),
                args.l2_normalize,
                args.batch_size,
            )
            query_acc = query_corr / query_num
            print("Learner accuracy: {} ({}/{})".format(
                query_acc, query_corr, query_num))
    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        pass
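# The `embed` helper used above is defined elsewhere in the repo. A sketch of
# what it plausibly does, based on the chunked forward pass in the embedding
# extraction script further below: split a long video along the time axis into
# chunks of at most `max_batch_size` frames (a multiple of the number of
# context frames), embed each chunk, and optionally L2-normalize. Gradient
# handling (e.g. wrapping in `torch.no_grad()` for pure feature extraction
# versus keeping the graph when fine-tuning the encoder) is left to the
# caller. Treat this as an assumption, not the repo's exact implementation.
def embed(model, frames, device, max_batch_size, l2_normalize):
    if frames.shape[1] > max_batch_size:
        chunks = []
        for i in range(math.ceil(frames.shape[1] / max_batch_size)):
            sub = frames[:, i * max_batch_size:(i + 1) * max_batch_size]
            chunks.append(model(sub.to(device))["embs"])
        embs = torch.cat(chunks, dim=1)
    else:
        embs = model(frames.to(device))["embs"]
    if l2_normalize:
        embs = F.normalize(embs, p=2, dim=-1)
    return embs  # (B, T, D); embs[0] is the embedding sequence of one video.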
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # Initialize experiment.
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
        "ACTION_CLASS",
        [],
        # "IMAGE_SIZE",
        # (216, 224),
        "DATASET",
        "embodied_glasses",
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # Load TCC model and checkpoint.
    debug = {
        "sample_sequential": True,
        "augment": False,
        "labeled": "both",
    }
    tcc_model, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug=debug)
    checkpoint.CheckpointManager.load_latest_checkpoint(
        checkpoint.Checkpoint(tcc_model),
        osp.join(config.DIRS.CKPT_DIR,
                 osp.basename(osp.normpath(args.logdir))),
        device,
    )
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(tcc_model),
        osp.join(config.DIRS.CKPT_DIR,
                 osp.basename(osp.normpath(args.logdir))),
        device,
    )
    init_step = checkpoint_manager.restore_or_initialize()

    # tcc_model.to(device).eval()
    tcc_model.to(device)
    tcc_model.featurizer_net.eval()
    tcc_model.encoder_net.train()
    freeze_model(tcc_model.featurizer_net, True, True)

    # Create LSTM model.
    lstm = LSTMModel(
        in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE,
        lstm_dim=512,
        out_dim=1,
        num_layers=1,
    )
    # lstm = MLPModel(
    #     in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE,
    #     hidden_dim=128,
    #     out_dim=1,
    # )
    lstm.to(device).train()

    # optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3, weight_decay=0)
    optimizer = torch.optim.Adam(
        [
            {"params": lstm.parameters(), "weight_decay": 1e-5},
            {"params": tcc_model.parameters(), "lr": 1e-4},
        ],
        lr=1e-3,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[100, 250],
                                                     gamma=0.1)

    # Figure out the max batch size that's a multiple of the number of
    # context frames. This is so we can support large videos with many frames.
    lcm = tcc_model.num_ctx_frames
    max_batch_size = math.floor(128 / lcm) * lcm

    query_metric = test_query(
        device,
        lstm,
        tcc_model,
        max_batch_size,
        (
            loaders["downstream_train"][args.query],
            loaders["downstream_valid"][args.query],
        ),
        args.l2_normalize,
        args.batch_size,
    )
    conf_matrix = query_metric["confusion_matrix"]
    num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
    num_total = conf_matrix.ravel().sum()
    print(
        f"Random LSTM accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
    )

    global_step = 0
    max_acc = 0
    complete = False
    losses, expert_acc, learner_acc = [], [], []
    try:
        while not complete:
            accs = []
            lstm.train()
            tcc_model.encoder_net.train()
            for action_name, loader in loaders["downstream_train"].items():
                if action_name != args.query:
                    batch_embs, batch_labels, batch_names = [], [], []
                    correct, num_x = 0, 0
                    for batch_idx, batch in enumerate(loader):
                        if len(batch_embs) < args.batch_size:
                            frames = batch["frames"]
                            # b, t, c, h, w = frames.shape
                            # frames = frames.view(
                            #     b, t // tcc_model.num_ctx_frames,
                            #     tcc_model.num_ctx_frames, c, h, w)
                            # img = UnNormalize()(frames[:, -1, -1])[0].permute(
                            #     1, 2, 0).cpu().numpy()
                            # plt.imshow(img)
                            # plt.show()
                            # set_trace()
                            batch_names.extend(batch["video_name"][0])
                            embs = embed(
                                tcc_model,
                                frames,
                                device,
                                max_batch_size,
                                args.l2_normalize,
                            )
                            batch_embs.append(embs[0])
                            batch_labels.extend(batch["success"])
                        if len(batch_embs) == args.batch_size or batch_idx == (
                                len(loader) - 1):
                            # batch_embs, batch_labels = augment_batch(batch_embs, batch_labels)

                            # Sort the list of embeddings by sequence length
                            # in descending order.
                            idxs_sorted = np.argsort(
                                [-len(x) for x in batch_embs])
                            batch_embs = [batch_embs[i] for i in idxs_sorted]
                            batch_labels = torch.stack(
                                [batch_labels[i] for i in idxs_sorted])
                            batch_labels = (
                                batch_labels.unsqueeze(1).float().to(device))
                            batch_names = [batch_names[i] for i in idxs_sorted]

                            # Forward through the LSTM and compute the loss.
                            out = lstm(batch_embs)
                            loss = F.binary_cross_entropy_with_logits(
                                out, batch_labels)
                            losses.append(loss.item())

                            # Backprop.
                            optimizer.zero_grad()
                            loss.backward()
                            # torch.nn.utils.clip_grad_norm_(lstm.parameters(), 1.0)
                            optimizer.step()
                            scheduler.step()
                            for param_group in optimizer.param_groups:
                                print(global_step, param_group["lr"])

                            # Compute accuracy.
                            with torch.no_grad():
                                probs = torch.sigmoid(out)
                                pred = probs > 0.5
                                # corr = pred.eq(batch_labels)
                                # pred_np = corr.flatten().cpu().numpy()
                                # mistakes = [batch_names[i] for i in np.argwhere(
                                #     pred_np == False).flatten().tolist()]
                                # if global_step > 100:
                                #     for m in mistakes:
                                #         print(m)
                                correct += (pred.eq(
                                    batch_labels.view_as(pred)).sum().item())
                                num_x += len(pred)

                            logging.info("{}: Loss: {:.6f}".format(
                                global_step, loss.item()))
                            global_step += 1
                            batch_embs, batch_labels, batch_names = [], [], []

                            # Exit if complete.
                            if global_step >= args.max_iters:
                                complete = True
                                break
                    if num_x > 0:
                        acc = correct / num_x
                        logging.info("{} accuracy: {:.2f} ({}/{})".format(
                            action_name, acc, correct, num_x))
                        accs.append(acc)

            mean_acc = np.mean(accs)
            expert_acc.append(mean_acc)
            print("Mean embodiment accuracy: {}".format(mean_acc))
            plot_loss(losses, osp.join(args.logdir, "lstm_loss.png"))

            query_metric = test_query(
                device,
                lstm,
                tcc_model,
                max_batch_size,
                (
                    loaders["downstream_train"][args.query],
                    loaders["downstream_valid"][args.query],
                ),
                args.l2_normalize,
                args.batch_size,
            )
            learner_acc.append(query_metric["accuracy"])
            plot_acc(expert_acc, learner_acc,
                     osp.join(args.logdir, "lstm_acc.png"))
            if query_metric["accuracy"] > max_acc:
                max_acc = query_metric["accuracy"]
            conf_matrix = query_metric["confusion_matrix"]
            num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
            num_total = conf_matrix.ravel().sum()
            print(
                f"Learner accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
            )
    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        conf_matrix = query_metric["confusion_matrix"]
        np.savetxt(osp.join(args.logdir, "confusion_matrix_lstm.txt"),
                   conf_matrix)
        plot_auc(query_metric, osp.join(args.logdir, "auc_curve_lstm.png"))
        learner_acc.append(query_metric["accuracy"])
        plot_acc(expert_acc, learner_acc,
                 osp.join(args.logdir, "lstm_acc.png"))
        plot_loss(losses, osp.join(args.logdir, "lstm_loss.png"))
        checkpoint_manager.save(init_step + global_step)
        print("best acc achieved: {}".format(max_acc))
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # Initialize experiment.
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # Load model and data loaders.
    debug = {
        "sample_sequential": not args.shuffle,
        "augment": False,
        "labeled": None,  # "both" if args.negative else "pos",
    }
    model, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug=debug)
    loaders = (loaders["downstream_valid"]
               if args.split == "valid" else loaders["downstream_train"])

    # Load model checkpoint.
    if args.model_ckpt is not None:
        checkpoint.Checkpoint(model).restore(args.model_ckpt, device)
    else:
        checkpoint.CheckpointManager.load_latest_checkpoint(
            checkpoint.Checkpoint(model),
            osp.join(config.DIRS.CKPT_DIR,
                     osp.basename(osp.normpath(args.logdir))),
            device,
        )
    model.to(device).eval()

    # Figure out the max batch size that's a multiple of the number of
    # context frames. This is so we can support large videos with many frames.
    lcm = model.num_ctx_frames
    max_batch_size = math.floor(128 / lcm) * lcm

    # Create save folder.
    save_path = osp.join(
        config.DIRS.DIR,
        args.save_path,
        osp.basename(osp.normpath(args.logdir)),
        "embs",
    )
    file_utils.mkdir(save_path)

    # Iterate over every action class.
    pbar = tqdm.tqdm(loaders.items(), leave=False)
    for action_name, loader in pbar:
        msg = "embedding {}".format(action_name)
        pbar.set_description(msg)

        (
            embeddings,
            seq_lens,
            steps,
            vid_frames,
            names,
            labels,
            phase_labels,
        ) = ([] for _ in range(7))
        for batch_idx, batch in enumerate(loader):
            if args.max_embs != -1 and batch_idx >= args.max_embs:
                break

            # Unpack batch data.
            frames = batch["frames"]
            chosen_steps = batch["frame_idxs"].to(device)
            seq_len = batch["video_len"].to(device)
            name = batch["video_name"][0]
            # label = batch["success"][0]
            # phase_label = None
            # if "phase_labels" in batch:
            #     phase_label = batch["phase_labels"].to(device)

            # Forward through the model to compute embeddings.
            with torch.no_grad():
                if frames.shape[1] > max_batch_size:
                    embs = []
                    for i in range(math.ceil(frames.shape[1] / max_batch_size)):
                        sub_frames = frames[:, i * max_batch_size:(i + 1) *
                                            max_batch_size].to(device)
                        sub_embs = model(sub_frames)["embs"]
                        embs.append(sub_embs.cpu())
                    embs = torch.cat(embs, dim=1)
                else:
                    embs = model(frames.to(device))["embs"]

            # Store.
            embeddings.append(embs.cpu().squeeze().numpy())
            seq_lens.append(seq_len.cpu().squeeze().numpy())
            steps.append(chosen_steps.cpu().squeeze().numpy())
            # if phase_label is not None:
            #     phase_labels.append(phase_label.cpu().squeeze().numpy())
            names.append(name)
            # labels.append(label.item())
            if args.keep_frames:
                frames = frames[0]
                frames = UnNormalize()(frames)
                frames = frames.view(
                    embs.shape[1],
                    config.SAMPLING.NUM_CONTEXT_FRAMES,
                    *frames.shape[1:],
                )
                vid_frames.append(
                    frames.cpu().squeeze().numpy()[:, -1].transpose(0, 2, 3, 1))

        data = {
            "embs": embeddings,
            "seq_lens": seq_lens,
            "steps": steps,
            "names": names,
            # "labels": labels,
        }
        if args.keep_frames:
            data["frames"] = vid_frames
        # if phase_labels:
        #     data["phase_labels"] = phase_labels
        np.save(osp.join(save_path, action_name), data)
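# The dictionaries saved above with `np.save` contain Python objects (lists of
# arrays), so reading them back requires `allow_pickle=True` and `.item()` to
# recover the dict from the 0-d object array. A small sketch of how the saved
# file for one action class might be consumed:
def load_embeddings(save_path, action_name):
    data = np.load(osp.join(save_path, action_name + ".npy"),
                   allow_pickle=True).item()
    for name, embs, seq_len in zip(data["names"], data["embs"],
                                   data["seq_lens"]):
        print(f"{name}: {embs.shape[0]} embeddings, video length {seq_len}")
    return data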
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info("Using GPU {}.".format(
            torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # Initialize experiment.
    opts = [
        "SAMPLING.STRIDE_ALL_SAMPLER",
        args.stride,
        "ACTION_CLASS",
        [],
        "DATASET",
        "embodied_glasses",
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # Determine the number of context frames.
    num_ctx_frames = config.SAMPLING.NUM_CONTEXT_FRAMES
    print(f"ctx frames: {num_ctx_frames}")

    # Load TCC model and checkpoint.
    debug = {
        "sample_sequential": True,
        "augment": False,
        "labeled": "both",
    }
    _, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug=debug)

    # Create model.
    model = Model(in_dim=3, out_dim=1)
    model.to(device).train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[100, 250],
                                                     gamma=0.1)

    query_metric = test_query(
        device,
        model,
        (
            loaders["downstream_train"][args.query],
            loaders["downstream_valid"][args.query],
        ),
        args.batch_size,
        num_ctx_frames,
    )
    conf_matrix = query_metric["confusion_matrix"]
    num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
    num_total = conf_matrix.ravel().sum()
    print(
        f"Initial pretrained model accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
    )

    global_step = 0
    max_acc = 0
    complete = False
    losses, expert_acc, learner_acc = [], [], []
    try:
        while not complete:
            accs = []
            model.train()
            for action_name, loader in loaders["downstream_train"].items():
                if action_name != args.query:
                    batch_embs, batch_labels, batch_names = [], [], []
                    correct, num_x = 0, 0
                    for batch_idx, batch in enumerate(loader):
                        if len(batch_embs) < args.batch_size:
                            frames = batch["frames"]
                            b, t, c, h, w = frames.shape
                            frames = frames.view(b, t // num_ctx_frames,
                                                 num_ctx_frames, c, h, w)
                            frame = frames[:, -1, -1]
                            # img = UnNormalize()(frame)[0].permute(
                            #     1, 2, 0).cpu().numpy()
                            # plt.imshow(img)
                            # plt.show()
                            # set_trace()
                            batch_names.extend(batch["video_name"][0])
                            batch_embs.append(frame)
                            batch_labels.extend(batch["success"])
                        if len(batch_embs) == args.batch_size or batch_idx == (
                                len(loader) - 1):
                            batch_embs = torch.cat(batch_embs, dim=0).to(device)
                            batch_labels = (torch.stack(
                                batch_labels).unsqueeze(1).float().to(device))

                            # Forward through the model and compute the loss.
                            out = model(batch_embs)
                            loss = F.binary_cross_entropy_with_logits(
                                out, batch_labels)
                            losses.append(loss.item())

                            # Backprop.
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                            scheduler.step()
                            for param_group in optimizer.param_groups:
                                print(global_step, param_group["lr"])

                            # Compute accuracy.
                            with torch.no_grad():
                                probs = torch.sigmoid(out)
                                pred = probs > 0.5
                                correct += (pred.eq(
                                    batch_labels.view_as(pred)).sum().item())
                                num_x += len(pred)

                            logging.info("{}: Loss: {:.6f}".format(
                                global_step, loss.item()))
                            global_step += 1
                            batch_embs, batch_labels, batch_names = [], [], []

                            # Exit if complete.
                            if global_step >= args.max_iters:
                                complete = True
                                break
                    if num_x > 0:
                        acc = correct / num_x
                        logging.info("{} accuracy: {:.2f} ({}/{})".format(
                            action_name, acc, correct, num_x))
                        accs.append(acc)

            mean_acc = np.mean(accs)
            expert_acc.append(mean_acc)
            print("Mean embodiment accuracy: {}".format(mean_acc))
            plot_loss(losses, osp.join(args.logdir, "conv_loss.png"))

            query_metric = test_query(
                device,
                model,
                (
                    loaders["downstream_train"][args.query],
                    loaders["downstream_valid"][args.query],
                ),
                args.batch_size,
                num_ctx_frames,
            )
            learner_acc.append(query_metric["accuracy"])
            plot_acc(expert_acc, learner_acc,
                     osp.join(args.logdir, "conv_acc.png"))
            if query_metric["accuracy"] > max_acc:
                max_acc = query_metric["accuracy"]
            conf_matrix = query_metric["confusion_matrix"]
            num_correct = conf_matrix[0, 0] + conf_matrix[1, 1]
            num_total = conf_matrix.ravel().sum()
            print(
                f"Learner accuracy: {query_metric['accuracy']} ({num_correct}/{num_total})"
            )
    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        conf_matrix = query_metric["confusion_matrix"]
        np.savetxt(osp.join(args.logdir, "confusion_matrix_conv.txt"),
                   conf_matrix)
        plot_auc(query_metric, osp.join(args.logdir, "auc_curve_conv.png"))
        learner_acc.append(query_metric["accuracy"])
        plot_acc(expert_acc, learner_acc,
                 osp.join(args.logdir, "conv_acc.png"))
        plot_loss(losses, osp.join(args.logdir, "conv_loss.png"))
        print("best acc achieved: {}".format(max_acc))
def main(args):
    log_dir = osp.join(CONFIG.DIRS.LOG_DIR, args.experiment_name)

    # Initialize experiment.
    config, device = experiment_utils.init_experiment(log_dir, CONFIG,
                                                      args.config_path)

    # Setup logger.
    logger = Logger(log_dir, args.resume)

    # Load model, data loaders and trainer.
    debug = {"sample_sequential": True, "num_workers": 4}
    (
        model,
        optimizer,
        loaders,
        trainer,
        evaluators,
    ) = experiment_utils.get_factories(config, device, debug=debug)

    # Create checkpoint manager.
    checkpoint_manager = checkpoint.CheckpointManager(
        checkpoint.Checkpoint(model, optimizer),
        osp.join(config.DIRS.CKPT_DIR, args.experiment_name),
        device,
    )

    global_step = checkpoint_manager.restore_or_initialize()
    epoch = int(global_step / len(loaders["pretrain_train"]))
    complete = False
    stopwatch = experiment_utils.Stopwatch()
    try:
        while not complete:
            logger.log_learning_rate(optimizer, global_step)
            for batch_idx, batch in enumerate(loaders["pretrain_train"]):
                if batch_idx >= args.num_samples:
                    break

                # Train one iteration.
                loss = trainer.train_one_iter(batch, global_step)

                # Save model checkpoint.
                if not global_step % config.CHECKPOINT.SAVE_INTERVAL:
                    checkpoint_manager.save(global_step)

                if not global_step % config.LOGGING.REPORT_INTERVAL:
                    # Train loss with the model in eval mode.
                    train_eval_loss = trainer.eval_num_iters(
                        loaders["pretrain_train"],
                        (args.num_samples // config.BATCH_SIZE),
                    )
                    losses = {"train": loss, "train_eval": train_eval_loss}
                    logger.log_loss(losses, global_step)
                    logging.info(
                        "Iter[{}/{}] (Epoch {}), Train (eval) Loss: {:.3f}".format(
                            global_step,
                            config.TRAIN.MAX_ITERS,
                            epoch,
                            train_eval_loss.item(),
                        ))

                if not global_step % config.LOGGING.EVAL_INTERVAL:
                    for eval_name, evaluator in evaluators.items():
                        eval_metric_train = evaluator.evaluate(
                            global_step,
                            model,
                            loaders["downstream_train"],
                            device,
                            args.num_samples,
                            msg="train",
                        )
                        metrics = {"train": eval_metric_train}
                        logger.log_metric(metrics, global_step, eval_name)
                        logging.info(
                            "Iter[{}/{}] (Epoch {}) - {} train: {:.4f}".format(
                                global_step,
                                config.TRAIN.MAX_ITERS,
                                epoch,
                                eval_name,
                                eval_metric_train["scalar"],
                            ))

                # Exit if complete.
                global_step += 1
                if global_step >= config.TRAIN.MAX_ITERS:
                    complete = True
                    break

                time_per_iter = stopwatch.elapsed()
                logging.info(
                    "Iter[{}/{}] (Epoch {}), {:.1f}s/iter, Loss: {:.3f}".format(
                        global_step,
                        config.TRAIN.MAX_ITERS,
                        epoch,
                        time_per_iter,
                        loss.item(),
                    ))
                stopwatch.reset()
            epoch += 1
    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        checkpoint_manager.save(global_step)
        logger.close()
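# `experiment_utils.Stopwatch` is referenced above but defined elsewhere. From
# its usage (elapsed() once per iteration, then reset()), it is presumably a
# thin wrapper around time.time(); a minimal sketch under that assumption:
import time


class Stopwatch:
    def __init__(self):
        self.reset()

    def elapsed(self):
        # Seconds since construction or the last reset().
        return time.time() - self._start

    def reset(self):
        self._start = time.time()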
def main(args):
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logging.info(
            "Using GPU {}.".format(torch.cuda.get_device_name(device)))
    else:
        logging.info("No GPU found. Falling back to CPU.")
        device = torch.device("cpu")

    # Initialize experiment.
    opts = [
        # "SAMPLING.STRIDE_ALL_SAMPLER",
        # args.stride,
        "IMAGE_SIZE",
        (64, 64),
        "ACTION_CLASS",
        [args.query],
        "DATASET",
        "demo",
    ]
    config_path = osp.join(args.logdir, "config.yml")
    config, device = experiment_utils.init_experiment(
        args.logdir,
        CONFIG,
        config_path,
        opts,
    )

    # Load TCC model and checkpoint.
    debug = {
        "sample_sequential": False,
        "augment": False,
        "labeled": None,
    }
    _, _, loaders, _, _ = experiment_utils.get_factories(
        config, device, debug=debug)
    # checkpoint.CheckpointManager.load_latest_checkpoint(
    #     checkpoint.Checkpoint(tcc_model),
    #     osp.join(
    #         config.DIRS.CKPT_DIR, osp.basename(osp.normpath(args.logdir))
    #     ),
    #     device,
    # )

    # Freeze TCC weights.
    # tcc_model.to(device)
    # tcc_model.featurizer_net.eval()
    # freeze_model(tcc_model.featurizer_net, False, False)
    # tcc_model.encoder_net.train()

    probe = Net(out_dim=5)
    # probe = LinearProbe(in_dim=config.MODEL.EMBEDDER.EMBEDDING_SIZE, out_dim=5)
    probe.to(device).train()

    optimizer = torch.optim.Adam(
        [
            {"params": probe.parameters(), "weight_decay": 0},
            # {"params": tcc_model.parameters()},
        ],
        lr=1e-3,
    )

    # Figure out the max batch size that's a multiple of the number of context
    # frames. This is so we can support large videos with many frames.
    lcm = 1  # tcc_model.num_ctx_frames
    max_batch_size = math.floor(32 / lcm) * lcm

    # Grab dataloaders.
    train_loader = loaders["pretrain_train"]  # [args.query]
    test_loader = loaders["pretrain_valid"]  # [args.query]

    global_step = 0
    complete = False
    try:
        while not complete:
            probe.train()
            # tcc_model.encoder_net.train()
            for batch_idx, batch in enumerate(train_loader):
                frames = batch["frames"].to(device)
                target = batch["debris_nums"].to(device)

                # Forward pass.
                loss, acc = embed_with_grad(
                    None,  # tcc_model,
                    probe,
                    frames,
                    target,
                    device,
                    max_batch_size,
                    args.l2_normalize,
                    optimizer,
                )
                print(f"{global_step}: Loss: {loss} - Acc: {100 * acc}")
                global_step += 1

                # Get test accuracy.
                test_probe(
                    test_loader,
                    None,  # tcc_model,
                    probe,
                    device,
                    max_batch_size,
                    args.l2_normalize,
                    global_step,
                )

                # Exit if complete.
                if global_step >= args.max_iters:
                    complete = True
                    break
    except KeyboardInterrupt:
        logging.info(
            "Caught keyboard interrupt. Saving model before quitting.")
    finally:
        print("Exiting.")
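# `test_probe` is defined elsewhere in the repo. A hedged sketch of what such
# an evaluation might do: run the probe over the test loader in eval mode and
# report classification accuracy on the "debris_nums" target. The TCC model,
# chunked batching, and normalization arguments are ignored here for brevity,
# which is a simplifying assumption rather than the repo's actual behavior.
def test_probe_sketch(test_loader, probe, device, global_step):
    probe.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in test_loader:
            frames = batch["frames"].to(device)
            target = batch["debris_nums"].to(device)
            logits = probe(frames)
            pred = logits.argmax(dim=-1)
            correct += (pred == target.view_as(pred)).sum().item()
            total += pred.numel()
    print(f"{global_step}: test accuracy: {100 * correct / total:.2f}%")
    probe.train()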