# per-epoch history of the training/validation metrics tracked in this stage
epoch_train_loss_I = []
epoch_train_acc_I = []
epoch_val_loss_I = []
epoch_val_acc_I = []
max_val_acc_I = 0
epoch_train_loss_C = []
epoch_val_loss_C = []
epoch_train_loss_tot = []
epoch_val_loss_tot = []

for epoch in range(n_epochs):
    # TRAIN
    model_B.train()
    model_I.train()
    correct_B = 0
    train_loss_B = 0
    correct_I = 0
    train_loss_I = 0
    train_loss_C = 0
    train_loss_tot = 0
    train_num = 0
    for i, (XI, XB, y) in enumerate(train_loader):
        XI, XB, y = XI.to(device), XB.to(device), y.long().to(device)
        # stop at the final incomplete batch
        if XI.size()[0] != batch_size:
            break
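        # --- The original cell is truncated at this point. ---
        # The lines below are a hypothetical sketch of how a joint training step
        # for model_B and model_I could continue, matching the running statistics
        # declared above. `criterion`, `lambda_C`, `optimizer_B`, `optimizer_I`,
        # and `F` (torch.nn.functional) are assumed to exist from earlier cells;
        # these names and the choice of MSE as the consistency term are
        # assumptions, not taken from the original notebook.
        out_B = model_B(XB)
        out_I = model_I(XI)
        loss_B = criterion(out_B, y)            # classification loss of model_B
        loss_I = criterion(out_I, y)            # classification loss of model_I
        loss_C = F.mse_loss(out_B, out_I)       # assumed consistency term between the two outputs
        loss_tot = loss_B + loss_I + lambda_C * loss_C
        optimizer_B.zero_grad()
        optimizer_I.zero_grad()
        loss_tot.backward()
        optimizer_B.step()
        optimizer_I.step()
        # accumulate the running statistics declared above
        train_loss_B += loss_B.item()
        train_loss_I += loss_I.item()
        train_loss_C += loss_C.item()
        train_loss_tot += loss_tot.item()
        correct_B += (out_B.argmax(dim=1) == y).sum().item()
        correct_I += (out_I.argmax(dim=1) == y).sum().item()
        train_num += y.size(0)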
# In[8]:


# 1st stage training: with recon_loss
training_start = datetime.now()

# split fit
epoch_train_loss = []
epoch_train_acc = []
epoch_val_loss = []
epoch_val_acc = []
max_val_acc = 0

for epoch in range(n_epochs):
    # TRAIN
    model.train()
    correct = 0
    train_loss = 0
    train_num = 0
    for i, (XI, XB, y) in enumerate(train_loader):
        # leave out the last `removal` batches of the loader
        if i >= len(train_loader) - removal:
            break
        # route XI to models with a CNN header, XB otherwise
        if model.header == 'CNN':
            x = XI
        else:
            x = XB
        x, y = x.to(device), y.long().to(device)
        # stop at the final incomplete batch
        if x.size()[0] != batch_size:
            # print("batch {} size {} < {}, skip".format(i, x.size()[0], batch_size))
            break
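# Both loops above drop incomplete batches by breaking out of the loader whenever
# a batch is smaller than `batch_size`. A minimal alternative sketch, assuming a
# `train_dataset` built in an earlier cell (name assumed): DataLoader's
# `drop_last=True` discards the final incomplete batch on its own, which makes
# the explicit size check unnecessary.
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, drop_last=True)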