model = torch.load(config["resume_snapshot"], map_location=device) else: model = importlib.import_module(config["model"]).Model(config, device).to(device) criterion = nn.MSELoss() opt = O.Adam(model.parameters(), lr=config["optimizer"]["learning_rate"]) iterations = 0 start = time.time() best_valid_loss = -1 header = ' Time Epoch Iteration Progress (%Epoch) Loss' dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f}'.split(',')) log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f}'.split(',')) print(header) with experiment.train(): for epoch in range(config["training"]["epochs"]): for batch_idx, (X_batch, y_batch) in enumerate(training_generator): X_batch, y_batch = X_batch.to(device), y_batch.to(device) X_batch, y_batch = X_batch.permute(1, 0, 2), y_batch.permute(1, 0, 2) train_loss = train(X_batch, y_batch, model, opt, criterion, config["clip"]) experiment.log_metric("train_loss", train_loss, step=iterations) # checkpoint model periodically if iterations % config["every"]["save"] == 0: snapshot_prefix = os.path.join(config["result_directory"], 'snapshot') snapshot_path = snapshot_prefix + '_loss_{:.6f}_iter_{}_model.pt'.format(train_loss, iterations) torch.save({ 'model': model.state_dict(), 'opt': opt.state_dict(), }, snapshot_path)
def _make_comet_experiment(args):
    """Create, configure and return the Comet experiment for this run.

    Uses an ``OfflineExperiment`` writing to ``opt['local_comet_dir']`` when
    that entry is truthy, otherwise an online ``CometExperiment``. ``vars(args)``
    is logged as the experiment parameters and ``args.name`` becomes its name.
    """
    # NOTE(review): `opt` is not a parameter of main(); it is presumably a
    # module-level dict defined elsewhere in this file -- confirm, otherwise
    # this line raises NameError.
    # SECURITY NOTE(review): the Comet API key is hard-coded in source. It can
    # now be overridden via the COMET_API_KEY environment variable; the literal
    # remains only as a backward-compatible fallback and should be removed.
    comet_kwargs = dict(
        api_key=os.environ.get("COMET_API_KEY", "hIXq6lDzWzz24zgKv7RYz6blo"),
        project_name="selfcifar",
        workspace="cinjon",
        auto_metric_logging=True,
        auto_output_logging=None,
        auto_param_logging=False,
    )
    if opt['local_comet_dir']:
        comet_exp = OfflineExperiment(offline_directory=opt['local_comet_dir'],
                                      **comet_kwargs)
    else:
        comet_exp = CometExperiment(**comet_kwargs)
    comet_exp.log_parameters(vars(args))
    comet_exp.set_name(args.name)
    return comet_exp


def _strip_first_key_component(state_dict):
    """Return a copy of ``state_dict`` with the first dotted component of every
    key removed (e.g. ``'module.conv1.weight'`` -> ``'conv1.weight'``)."""
    return {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}


def _build_models(args):
    """Build the backbone selected by ``args.model`` and its classifier head.

    The backbone loads pretrained weights unless ``args.not_pretrain`` is set.
    The head is ``NonLinearModel`` if ``args.do_nonlinear`` else ``LinearModel``,
    mapping the backbone's output dimension to ``args.num_classes``.

    Returns:
        ``(model, linear_model)``, not yet moved to any device.

    Raises:
        Exception: if ``args.model`` is not amdim, ccc, corrflow or resnet.
    """
    linear_cls = NonLinearModel if args.do_nonlinear else LinearModel
    pretrain = not args.not_pretrain

    if args.model == "amdim":
        hparams = load_hparams_from_tags_csv(
            '/checkpoint/cinjon/amdim/meta_tags.csv')
        model = AMDIMModel(hparams)
        if pretrain:
            ckpt_path = '/checkpoint/cinjon/amdim/_ckpt_epoch_434.ckpt'
            model.load_state_dict(torch.load(ckpt_path)["state_dict"])
        else:
            print("AMDIM not loading checkpoint")  # Debug
        linear_model = linear_cls(AMDIM_OUTPUT_DIM, args.num_classes)

    elif args.model == "ccc":
        model = CCCModel(None)
        if pretrain:
            ckpt_path = '/checkpoint/cinjon/spaceofmotion/bsn/TimeCycleCkpt14.pth'
            checkpoint = torch.load(ckpt_path)
            model.load_state_dict(
                _strip_first_key_component(checkpoint['state_dict']))
        else:
            print("CCC not loading checkpoint")  # Debug
        # Fixed: the original misspelled this as `linaer_cls` (NameError).
        linear_model = linear_cls(CCC_OUTPUT_DIM, args.num_classes)

    elif args.model == "corrflow":
        model = CORRFLOWModel(None)
        if pretrain:
            ckpt_path = '/checkpoint/cinjon/spaceofmotion/supercons/corrflow.kineticsmodel.pth'
            checkpoint = torch.load(ckpt_path)
            model.load_state_dict(
                _strip_first_key_component(checkpoint['state_dict']))
        else:
            print("CorrFlow not loading checkpoing")  # Debug
        linear_model = linear_cls(CORRFLOW_OUTPUT_DIM, args.num_classes)

    elif args.model == "resnet":
        resnet = torchvision.models.resnet50(pretrained=pretrain)
        if not pretrain:
            print("ResNet not loading checkpoint")  # Debug
        # Drop the final child module (the classification layer) to use the
        # network as a feature extractor.
        model = nn.Sequential(*list(resnet.children())[:-1])
        linear_model = linear_cls(RESNET_OUTPUT_DIM, args.num_classes)

    else:
        raise Exception("model type has to be amdim, ccc, corrflow or resnet")

    return model, linear_model


def _make_optimizer(args, params, lr):
    """Return an Adam or SGD optimizer over ``params`` per ``args.optimizer``.

    Raises:
        Exception: if ``args.optimizer`` is neither "Adam" nor "SGD".
    """
    if args.optimizer == "Adam":
        return optim.Adam(params, lr=lr, weight_decay=args.weight_decay)
    if args.optimizer == "SGD":
        return optim.SGD(params, lr=lr, momentum=args.momentum,
                         weight_decay=args.weight_decay)
    raise Exception("optimizer should be Adam or SGD")


def _make_cifar_loaders(args, batch_size):
    """Build CIFAR-10/100 train and test datasets and their DataLoaders.

    Returns:
        ``(train_dataset, train_loader, val_dataset, val_loader)``.

    Raises:
        Exception: if ``args.num_classes`` is neither 10 nor 100.
    """
    if args.num_classes == 10:
        data_path = "/private/home/cinjon/cifar-data/cifar-10-batches-py"
        # sorted() makes the file order deterministic (glob order is arbitrary).
        train_files = sorted(glob(os.path.join(data_path, "data*")))
        val_files = [os.path.join(data_path, "test_batch")]
    elif args.num_classes == 100:
        data_path = "/private/home/cinjon/cifar-data/cifar-100-python"
        train_files = [os.path.join(data_path, "train")]
        val_files = [os.path.join(data_path, "test")]
    else:
        raise Exception("num_classes should be 10 or 100")

    train_dataset = CIFAR_dataset(train_files, args.num_classes, args.model, True)
    val_dataset = CIFAR_dataset(val_files, args.num_classes, args.model, False)
    # num_workers is now honoured for CIFAR-100 too (it was only used for -10).
    train_loader = data.DataLoader(train_dataset, shuffle=True,
                                   batch_size=batch_size,
                                   num_workers=args.num_workers)
    val_loader = data.DataLoader(val_dataset, shuffle=False,
                                 batch_size=batch_size,
                                 num_workers=args.num_workers)
    return train_dataset, train_loader, val_dataset, val_loader


def _abort_if_out_of_time(args, start_time, comet_exp):
    """End the Comet experiment and exit with status -1 once elapsed time is
    within 300 s of the ``args.time``-hour budget."""
    if time.time() - start_time > args.time * 3600 - 300 and comet_exp is not None:
        comet_exp.end()
        sys.exit(-1)


def _prepare_batch(batch, args):
    """Move a ``(images, labels)`` batch to ``device``.

    Non-resnet backbones get an extra dimension inserted at axis 1 of the
    images. Labels are flattened.
    """
    imgs = batch[0].to(device)
    if args.model != "resnet":
        imgs = imgs.unsqueeze(1)
    lbls = batch[1].flatten().to(device)
    return imgs, lbls


def _train_one_epoch(model, linear_model, loader, optimizer, args, epoch,
                     start_time, comet_exp):
    """Train ``linear_model`` on features from the frozen ``model`` for one epoch.

    Returns:
        ``(mean_batch_loss, num_correct)`` over the epoch.
    """
    linear_model.train()
    loss_sum = 0.0
    num_correct = 0
    for step, batch in enumerate(loader):
        _abort_if_out_of_time(args, start_time, comet_exp)
        imgs, lbls = _prepare_batch(batch, args)
        # The backbone is frozen; no autograd graph is needed through it.
        with torch.no_grad():
            features = model(imgs)
        output = linear_model(features)
        loss = F.cross_entropy(output, lbls)
        loss_sum += loss.item()
        num_correct += (torch.argmax(output, dim=1) == lbls).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 1500 == 0:
            print("train epoch {}/{}\titer {}/{} loss:{}".format(
                epoch, args.epochs, step + 1, len(loader), loss.item()),
                flush=False)
    return loss_sum / len(loader), num_correct


def _evaluate(model, linear_model, loader, args, epoch, start_time, comet_exp):
    """Evaluate ``linear_model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_batch_loss, num_correct)``.
    """
    linear_model.eval()
    loss_sum = 0.0
    num_correct = 0
    with torch.no_grad():
        for step, batch in enumerate(loader):
            _abort_if_out_of_time(args, start_time, comet_exp)
            imgs, lbls = _prepare_batch(batch, args)
            output = linear_model(model(imgs))
            loss = F.cross_entropy(output, lbls)
            loss_sum += loss.item()
            num_correct += (torch.argmax(output, dim=1) == lbls).sum().item()
            if step % 1500 == 0:
                print("val epoch {}/{} iter {}/{} loss:{}".format(
                    epoch, args.epochs, step + 1, len(loader), loss.item()),
                    flush=False)
    return loss_sum / len(loader), num_correct


def main(args):
    """Train a classifier head on a frozen backbone for CIFAR-10/100.

    The backbone is selected by ``args.model`` (amdim, ccc, corrflow or resnet),
    optionally loaded from a pretrained checkpoint, and frozen. Only the head is
    optimised, with Adam or SGD, and a step-decayed learning rate floored at
    3e-4. Per-epoch train/val loss and accuracy are printed and logged to Comet.
    Head and backbone state dicts are saved to the log directory whenever
    validation accuracy improves. The process exits with status -1 when it
    nears the ``args.time``-hour budget.

    ``args`` attributes read:
        not_pretrain, model, name, do_nonlinear, num_classes, optimizer, lr,
        weight_decay, momentum, batch_size, num_workers, epochs, lr_interval,
        time.

    NOTE(review): ``device``, ``opt`` and the model/dataset classes are
    module-level names defined elsewhere in this file.
    """
    print('Pretrain? ', not args.not_pretrain)
    print(args.model)
    start_time = time.time()
    comet_exp = _make_comet_experiment(args)

    # Build model + head and place them on the device(s).
    model, linear_model = _build_models(args)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model).to(device)
        linear_model = nn.DataParallel(linear_model).to(device)
    else:
        model = model.to(device)
        linear_model = linear_model.to(device)

    # Freeze the backbone; only the head is trained.
    for p in model.parameters():
        p.requires_grad = False
    model.eval()

    optimizer = _make_optimizer(args, linear_model.parameters(), args.lr)
    if args.optimizer == "Adam":
        print("Optimizer: Adam with weight decay: {}".format(args.weight_decay))
    else:
        print("Optimizer: SGD with weight decay: {} momentum: {}".format(
            args.weight_decay, args.momentum))
    optimizer.zero_grad()

    log_dir = '/checkpoint/cinjon/spaceofmotion/bsn/cifar-%d-weights/%s/%s' % (
        args.num_classes, args.model, args.name)
    os.makedirs(log_dir, exist_ok=True)
    print("Saving to {}".format(log_dir))

    # max(1, ...) keeps the batch size positive on CPU-only hosts, where
    # device_count() is 0 (the original produced batch_size=0 there).
    batch_size = args.batch_size * max(1, torch.cuda.device_count())
    train_dataset, train_loader, val_dataset, val_loader = _make_cifar_loaders(
        args, batch_size)

    best_acc = 0.0
    best_epoch = 0
    for epoch in range(1, args.epochs + 1):
        current_lr = max(3e-4, args.lr *
                         math.pow(0.5, math.floor(epoch / args.lr_interval)))
        # Adjust the LR in place instead of rebuilding the optimizer each
        # epoch, which silently discarded Adam moments / SGD momentum buffers.
        for group in optimizer.param_groups:
            group['lr'] = current_lr

        # epoch is already 1-based; the original logged epoch + 1 (off by one).
        comet_step = epoch * len(train_loader)

        # Train
        train_loss, train_correct = _train_one_epoch(
            model, linear_model, train_loader, optimizer, args, epoch,
            start_time, comet_exp)
        train_acc = train_correct / len(train_dataset)
        with comet_exp.train():
            comet_exp.log_metrics({'acc': train_acc, 'loss': train_loss},
                                  step=comet_step, epoch=epoch)
        print("train acc epoch {}/{} loss:{} train_acc:{}".format(
            epoch, args.epochs, train_loss, train_acc), flush=True)

        # Validate
        val_loss, val_correct = _evaluate(
            model, linear_model, val_loader, args, epoch, start_time, comet_exp)
        val_acc = val_correct / len(val_dataset)
        print("val epoch {}/{} loss:{} val_acc:{}".format(
            epoch, args.epochs, val_loss, val_acc))
        with comet_exp.test():
            comet_exp.log_metrics({'acc': val_acc, 'loss': val_loss},
                                  step=comet_step, epoch=epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            torch.save(linear_model.state_dict(),
                       os.path.join(log_dir, "{}.linear.pth".format(epoch)))
            torch.save(model.state_dict(),
                       os.path.join(log_dir, "{}.model.pth".format(epoch)))

        # Check bias and variance
        print(
            "Epoch {} lr {} total: train_loss:{} train_acc:{} val_loss:{} val_acc:{}"
            .format(epoch, current_lr, train_loss, train_acc, val_loss, val_acc),
            flush=True)

    print("The best epoch: {} acc: {}".format(best_epoch, best_acc))